Ejemplo n.º 1
0
    def affine_optimization(self):
        """
        Run mermaid's optimization-based affine registration on the current pair.

        On the first call the affine settings (``self.setting_for_mermaid_affine``)
        are saved into ``self.record_path`` and the saved file is reused afterwards.
        If ``self.compute_inverse_map`` is set, the inverse map is built by applying
        the inverted affine parameters to an identity map.

        :return: warped image, transformation map, affine parameter Ab, loss (None)
        """
        self.si = SI.RegisterImagePair()
        extra_info = pars.ParameterDict()
        extra_info['pair_name'] = self.fname_list
        af_sigma = self.opt_mermaid['affine']['sigma']
        self.si.opt = None
        self.si.set_initial_map(None)
        # persist the affine settings once; later calls reuse the saved file
        if self.saved_affine_setting_path is None:
            self.saved_affine_setting_path = self.save_setting(
                self.setting_for_mermaid_affine, self.record_path,
                'affine_setting.json')

        # (settings, settings-with-comments) files written by mermaid after the run
        cur_affine_json_saving_path = (
            os.path.join(self.record_path, 'cur_settings_affine.json'),
            os.path.join(self.record_path, 'cur_settings_affine_comment.json'))
        self.si.register_images(
            self.moving,
            self.target,
            self.spacing,
            extra_info=extra_info,
            LSource=self.l_moving,
            LTarget=self.l_target,
            visualize_step=None,
            use_multi_scale=True,
            rel_ftol=0,
            similarity_measure_sigma=af_sigma,
            json_config_out_filename=cur_affine_json_saving_path,
            params=self.saved_affine_setting_path)

        self.output = self.si.get_warped_image()
        # detached copy, independent of the optimizer's internal map tensor
        self.phi = self.si.opt.optimizer.ssOpt.get_map().detach().clone()

        Ab = self.si.opt.optimizer.ssOpt.model.Ab

        if self.compute_inverse_map:
            inv_Ab = py_utils.get_inverse_affine_param(Ab.detach())
            identity_map = py_utils.identity_map_multiN(
                [1, 1] + self.input_img_sz, self.spacing)
            # place the identity map on the same device as the affine parameters
            # rather than hard-coding .cuda(), so CPU runs and non-default GPUs work
            identity_map = torch.Tensor(identity_map).to(inv_Ab.device)
            self.inversed_map = py_utils.apply_affine_transform_to_map_multiNC(
                inv_Ab, identity_map).detach()
        self.afimg_or_afparam = Ab
        save_affine_param_with_easyreg_custom(self.afimg_or_afparam,
                                              self.record_path,
                                              self.fname_list,
                                              affine_compute_from_mermaid=True)
        return self.output.detach_(), self.phi.detach_(
        ), self.afimg_or_afparam.detach_(), None
Ejemplo n.º 2
0
def nonp_optimization(si, moving, target, spacing, fname, l_moving=None, l_target=None, init_weight=None, expr_folder=None, mermaid_setting_path=None):
    """
    Run a mermaid non-parametric registration, optionally seeded by a previous one.

    :param si: RegisterImagePair of a previous registration whose final map is used
        as the initial map here, or None to start without an initial map
    :param moving: source image
    :param target: target image
    :param spacing: image spacing
    :param fname: pair name, forwarded to ``setting_visual_saving``
    :param l_moving: optional source label map
    :param l_target: optional target label map
    :param init_weight: optional initial weight map, frozen during the optimization
    :param expr_folder: folder forwarded to ``setting_visual_saving``
    :param mermaid_setting_path: mermaid settings, passed as ``params``
    :return: (warped image, transformation map, momentum ``m``, local weight map or None)
    """
    initial_map = si.opt.optimizer.ssOpt.get_map() if si is not None else None

    registration = SI.RegisterImagePair()
    extra_info = setting_visual_saving(expr_folder, fname)
    registration.opt = None
    if initial_map is not None:
        registration.set_initial_map(initial_map.detach())
    if init_weight is not None:
        registration.set_weight_map(init_weight.detach(), freeze_weight=True)

    registration.register_images(moving, target, spacing,
                                 extra_info=extra_info,
                                 LSource=l_moving,
                                 LTarget=l_target,
                                 map_low_res_factor=0.5,
                                 visualize_step=30,
                                 optimizer_name='lbfgs_ls',
                                 use_multi_scale=True,
                                 rel_ftol=0,
                                 similarity_measure_type='lncc',
                                 params=mermaid_setting_path)
    warped = registration.get_warped_image()
    transform_map = registration.opt.optimizer.ssOpt.get_map()
    model_param = registration.get_model_parameters()

    # exactly two model parameters are treated as 'm' plus 'local_weights'
    weight_map = model_param['local_weights'].detach() if len(model_param) == 2 else None
    return (warped.detach_(), transform_map.detach_(),
            model_param['m'].detach(), weight_map)
Ejemplo n.º 3
0
 def initialize(self, opt):
     """
     Set up the mermaid optimization-based registration from the task settings.

     ``opt['tsk_set']['method_name']`` selects the stages: 'affine' (affine only),
     'nonp' (affine, then non-parametric) or 'nonp_only' (non-parametric only).
     The non-parametric mermaid JSON is loaded here to record its registration
     model type.

     :param opt: ParameterDict, task settings
     :return: None
     """
     MermaidBase.initialize(self, opt)

     # method name -> (affine_on, nonp_on); any other name leaves both flags untouched
     stage_flags = {
         'affine': (True, False),
         'nonp': (True, True),
         'nonp_only': (False, True),
     }
     method_name = opt['tsk_set']['method_name']
     if method_name in stage_flags:
         self.affine_on, self.nonp_on = stage_flags[method_name]

     self.si = SI.RegisterImagePair()
     self.opt_optim = opt['tsk_set']['optim']
     self.compute_inverse_map = opt['tsk_set']['reg'][(
         'compute_inverse_map', False,
         "compute the inverse transformation map")]

     mermaid_opt = self.opt['tsk_set']['reg']['mermaid_iter']
     self.opt_mermaid = mermaid_opt
     self.use_init_weight = mermaid_opt[(
         'use_init_weight', False,
         'whether to use init weight for RDMM registration')]
     self.init_weight = None
     self.setting_for_mermaid_affine = mermaid_opt[(
         'mermaid_affine_json', '',
         'the json path for the setting for mermaid affine')]
     self.setting_for_mermaid_nonp = mermaid_opt[(
         'mermaid_nonp_json', '',
         'the json path for the setting for mermaid non-parametric')]

     # only the registration model type is kept from the non-parametric settings
     nonp_settings = pars.ParameterDict()
     nonp_settings.load_JSON(self.setting_for_mermaid_nonp)
     registration_model = nonp_settings['model']['registration_model']
     self.nonp_model_name = registration_model['type']

     self.weights_for_fg = mermaid_opt[(
         'weights_for_fg', [0, 0, 0, 0, 1.],
         'regularizer weight for the foregound area, this should be got from the mermaid_json file')]
     self.weights_for_bg = mermaid_opt[(
         'weights_for_bg', [0, 0, 0, 0, 1.],
         'regularizer weight for the background area')]

     # set to None here; filled in by the optimization methods
     self.saved_mermaid_setting_path = None
     self.saved_affine_setting_path = None
     self.inversed_map = None
     self.use_01 = True
Ejemplo n.º 4
0
    def __init__(self, reg_models):
        """
        Store the first two stages of a registration sequence.

        :param reg_models: a list of tuples (model_name:string, model_setting: json_file_name(string) or an ParameterDict object);
            only the first two entries are read
        """
        self.si_ = SI.RegisterImagePair()
        first_stage, second_stage = reg_models[0], reg_models[1]
        self.model_0_name, self.model_0_setting = first_stage
        self.model_1_name, self.model_1_setting = second_stage
        self.im_io = FIO.ImageIO()

        # image and mask placeholders, all None at construction
        self.target_image_np = self.moving_image_np = None
        self.target_mask = self.moving_mask = None

        # result placeholders, all None at construction
        self.Ab = self.map = self.inverse_map = None
Ejemplo n.º 5
0
def affine_optimization(moving, target, spacing, fname_list, l_moving=None, l_target=None,
                        mermaid_setting_path='../mermaid/mermaid_demos/rdmm_synth_data_generation/cur_settings_affine.json'):
    """
    Affinely register ``moving`` to ``target`` with mermaid's ``affine_map`` model.

    :param moving: source image
    :param target: target image
    :param spacing: image spacing
    :param fname_list: pair name(s), stored as ``extra_info['pair_name']``
    :param l_moving: optional source label map
    :param l_target: optional target label map
    :param mermaid_setting_path: mermaid settings, passed as ``params`` to
        ``register_images``; the default is a relative path, so it is resolved
        against the current working directory
    :return: (warped image, transformation map, affine parameters ``Ab``, the
        ``RegisterImagePair`` instance so callers can reuse its final map)
    """
    si = SI.RegisterImagePair()
    extra_info = {'pair_name': fname_list}
    si.opt = None
    si.set_initial_map(None)
    si.register_images(moving, target, spacing, extra_info=extra_info,
                       LSource=l_moving, LTarget=l_target,
                       model_name='affine_map',
                       map_low_res_factor=1.0,
                       nr_of_iterations=100,
                       visualize_step=None,
                       optimizer_name='sgd',
                       use_multi_scale=True,
                       rel_ftol=0,
                       similarity_measure_type='lncc',
                       similarity_measure_sigma=1.,
                       params=mermaid_setting_path)
    output = si.get_warped_image()
    # detached copy, independent of the optimizer's internal map tensor
    phi = si.opt.optimizer.ssOpt.get_map().detach().clone()
    # affine parameters of the model
    affine_param = si.opt.optimizer.ssOpt.model.Ab
    return output.detach_(), phi.detach_(), affine_param.detach_(), si
Ejemplo n.º 6
0
    # create a default image size with two sample squares
    I0, I1, spacing = EG.CreateSquares(dim,add_noise_to_bg).create_image_pair(szEx, params)
else:
    # return a real image example
    I0, I1, spacing = EG.CreateRealExampleImages(dim).create_image_pair() # create a default image size with two sample squares

##################################
# Creating the registration algorithm
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We simply instantiate a simple interface object for the registration of image pairs.
# We can then query it for the registration models that are currently supported.
#

# create a simple interface object for pair-wise image registration
si = SI.RegisterImagePair()

# print the names of the available registration models
si.print_available_models()

################################
# Doing the registration
# ^^^^^^^^^^^^^^^^^^^^^^
#
# We are now ready to perform the registration (picking one of the registration model options printed above).
#
# Here, we use a shooting-based LDDMM algorithm which works directly with transformation maps.
# The simple interface allows setting various registration settings.
#
# Of note, we read in a parameter file (``test2d_tst.json``) to parametrize the registration algorithm and
# write out the used parameters (after the run) into the same file as well as into a file with comments explaining them.
def do_registration(source_images,
                    target_images,
                    model_name,
                    output_directory,
                    nr_of_epochs,
                    nr_of_iterations,
                    map_low_res_factor,
                    visualize_step,
                    json_in,
                    json_out,
                    optimize_over_deep_network=False,
                    evaluate_but_do_not_optimize_over_shared_parameters=False,
                    load_shared_parameters_from_file=None,
                    optimize_over_weights=False,
                    freeze_parameters=False,
                    start_from_previously_saved_parameters=True,
                    args_kvs=None,
                    only_run_stage0_with_unchanged_config=False):
    """
    Run one stage of batch registration with a learned multi-Gaussian smoother.

    The settings come from ``json_in``. A ``pars.ParameterDict`` is used as-is and
    modified in place; anything else is treated as a JSON file path. The settings
    are amended with ``args_kvs`` and the stage arguments and passed to
    ``RegisterImagePair.register_images`` with batch optimization enabled.

    :param source_images: source images for ``register_images``
    :param target_images: target images for ``register_images``
    :param model_name: registration model name
    :param output_directory: batch parameter output directory
    :param nr_of_epochs: number of batch epochs
    :param nr_of_iterations: number of iterations; replaced by the key-value value
        when ``args_kvs`` adds ``optimizer.single_scale.nr_of_iterations``
    :param map_low_res_factor: map resolution factor; replaced by the key-value value
        when ``args_kvs`` adds it, read from the settings (default 1.0) when None
    :param visualize_step: forwarded to ``register_images``
    :param json_in: settings as ``pars.ParameterDict`` or path to a JSON file
    :param json_out: forwarded as ``json_config_out_filename``
    :param optimize_over_deep_network: smoother setting (not applied in debug mode)
    :param evaluate_but_do_not_optimize_over_shared_parameters: smoother setting
        (not applied in debug mode)
    :param load_shared_parameters_from_file: optional shared-parameter file; copied
        into ``<output_directory>/shared`` and, outside debug mode, set as the
        smoother's ``load_dnn_parameters_from_this_file``
    :param optimize_over_weights: smoother setting (not applied in debug mode)
    :param freeze_parameters: smoother setting
    :param start_from_previously_saved_parameters: batch setting
    :param args_kvs: extra key-value settings applied via ``add_key_value_pairs_to_params``
    :param only_run_stage0_with_unchanged_config: debug mode; the stage-specific
        smoother settings are not applied
    :return: None
    """
    if load_shared_parameters_from_file is not None:
        shared_target_dir = os.path.join(output_directory, 'shared')
        if not os.path.exists(shared_target_dir):
            print('INFO: creating current shared directory {}'.format(
                shared_target_dir))
            os.makedirs(shared_target_dir)
        print(
            'INFO: copying the shared parameter file {} to the current shared parameter directory {}'
            .format(load_shared_parameters_from_file, shared_target_dir))
        shutil.copy(load_shared_parameters_from_file, shared_target_dir)

    reg = si.RegisterImagePair()

    # load the json file if it is a file and make necessary modifications
    json_in_is_dict = isinstance(json_in, pars.ParameterDict)
    if json_in_is_dict:
        params_in = json_in
    else:
        params_in = pars.ParameterDict()
        print('Loading settings from file: {}'.format(json_in))
        params_in.load_JSON(json_in)

    # we need to check if nr_of_iterations or map_low_res_factor is set by the key-value arguments.
    # NOTE(review): only keys that were absent before add_key_value_pairs_to_params are detected;
    # a key-value pair that overwrites a key already present in json_in is not picked up here.
    # Confirm this is intended.
    has_iterations_before = params_in.has_key(
        ['optimizer', 'single_scale', 'nr_of_iterations'])
    has_map_low_res_factor_before = params_in.has_key(
        ['model', 'deformation', 'map_low_res_factor'])

    add_key_value_pairs_to_params(params_in, args_kvs)

    has_iterations_after = params_in.has_key(
        ['optimizer', 'single_scale', 'nr_of_iterations'])
    has_map_low_res_factor_after = params_in.has_key(
        ['model', 'deformation', 'map_low_res_factor'])

    kv_set_iterations = not has_iterations_before and has_iterations_after
    kv_set_map_low_res_factor = not has_map_low_res_factor_before and has_map_low_res_factor_after

    if kv_set_iterations:
        kv_nr_of_iterations = params_in['optimizer']['single_scale'][
            'nr_of_iterations']
        print(
            'INFO: nr_of_iterations was overwritten by key-value pair: {} -> {}'
            .format(nr_of_iterations, kv_nr_of_iterations))
        nr_of_iterations = kv_nr_of_iterations

    if kv_set_map_low_res_factor:
        kv_map_low_res_factor = params_in['model']['deformation'][
            'map_low_res_factor']
        print(
            'INFO: map_low_res_factor was overwritten by key-value pair: {} -> {}'
            .format(map_low_res_factor, kv_map_low_res_factor))
        map_low_res_factor = kv_map_low_res_factor

    if map_low_res_factor is None:
        map_low_res_factor = params_in['model']['deformation'][(
            'map_low_res_factor', 1.0, 'low res factor for the map')]

    batch_settings = params_in['optimizer']['batch_settings']
    batch_settings['nr_of_epochs'] = nr_of_epochs
    batch_settings['parameter_output_dir'] = output_directory
    batch_settings[
        'start_from_previously_saved_parameters'] = start_from_previously_saved_parameters

    smoother = params_in['model']['registration_model']['forward_model']['smoother']
    smoother['type'] = 'learned_multiGaussianCombination'
    smoother['start_optimize_over_smoother_parameters_at_iteration'] = 0
    smoother['freeze_parameters'] = freeze_parameters

    if not only_run_stage0_with_unchanged_config:
        # we use the setting of the stage
        smoother['optimize_over_deep_network'] = optimize_over_deep_network
        smoother[
            'evaluate_but_do_not_optimize_over_shared_registration_parameters'] = evaluate_but_do_not_optimize_over_shared_parameters
        smoother['optimize_over_smoother_stds'] = False
        smoother['optimize_over_smoother_weights'] = optimize_over_weights

        if load_shared_parameters_from_file is not None:
            smoother[
                'load_dnn_parameters_from_this_file'] = load_shared_parameters_from_file

    else:
        # '{:s}' only accepts str arguments; describe a ParameterDict input instead of crashing
        config_source = 'the provided ParameterDict' if json_in_is_dict else str(json_in)
        print('\n\n')
        print('-------------------------------------')
        print(
            'INFO: Overwriting the stage settings; using {:s} without modifications. Use this only for DEBUGGING!'
            .format(config_source))
        print('-------------------------------------')
        print('\n\n')

    spacing = None
    reg.register_images(source_images,
                        target_images,
                        spacing,
                        model_name=model_name,
                        nr_of_iterations=nr_of_iterations,
                        map_low_res_factor=map_low_res_factor,
                        visualize_step=visualize_step,
                        json_config_out_filename=json_out,
                        use_batch_optimization=True,
                        params=params_in)
Ejemplo n.º 8
0
    def nonp_optimization(self):
        """
        Run mermaid's non-parametric registration on the current pair.

        If the affine stage is enabled (``self.affine_on``), the final map of the
        previous registration held by ``self.si`` is used as the initial map,
        together with ``self.inversed_map``. If ``self.use_init_weight`` is set, an
        initial weight map for the multi-Gaussian regularizer is computed from the
        moving label map and frozen during optimization. When an affine map exists,
        that weight map is first warped by it. The registration model must support
        a spatially varying regularizer for this to be meaningful.

        :return: warped image, transformation map, affine image (or None), loss (None)
        """
        affine_map = None
        if self.affine_on:
            affine_map = self.si.opt.optimizer.ssOpt.get_map()

        self.si = SI.RegisterImagePair()
        extra_info = pars.ParameterDict()
        extra_info['pair_name'] = self.fname_list
        self.si.opt = None
        if affine_map is not None:
            self.si.set_initial_map(affine_map.detach(), self.inversed_map)

        if self.use_init_weight:
            init_weight = get_init_weight_from_label_map(
                self.l_moving, self.spacing, self.weights_for_bg,
                self.weights_for_fg)
            # the weights are computed from the moving label map; only warp them when an
            # affine map exists (in 'nonp_only' mode affine_map is None)
            if affine_map is not None:
                init_weight = py_utils.compute_warped_image_multiNC(
                    init_weight,
                    affine_map,
                    self.spacing,
                    spline_order=1,
                    zero_boundary=False)
            self.si.set_weight_map(init_weight.detach(), freeze_weight=True)

        # persist the non-parametric settings once; later calls reuse the saved file
        if self.saved_mermaid_setting_path is None:
            self.saved_mermaid_setting_path = self.save_setting(
                self.setting_for_mermaid_nonp, self.record_path,
                "nonp_setting.json")
        # (settings, settings-with-comments) files written by mermaid after the run
        cur_mermaid_json_saving_path = (
            os.path.join(self.record_path, 'cur_settings_nonp.json'),
            os.path.join(self.record_path, 'cur_settings_nonp_comment.json'))
        self.si.register_images(
            self.moving,
            self.target,
            self.spacing,
            extra_info=extra_info,
            LSource=self.l_moving,
            LTarget=self.l_target,
            visualize_step=None,
            use_multi_scale=True,
            rel_ftol=0,
            compute_inverse_map=self.compute_inverse_map,
            json_config_out_filename=cur_mermaid_json_saving_path,
            params=self.saved_mermaid_setting_path)
        # here return the affine image (the output of the preceding affine step).
        # NOTE(review): when self.affine_on is False this is whatever self.output held
        # before this call; confirm callers expect that.
        self.afimg_or_afparam = self.output
        self.output = self.si.get_warped_image()
        self.phi = self.si.opt.optimizer.ssOpt.get_map()

        if self.compute_inverse_map:
            self.inversed_map = self.si.get_inverse_map().detach()
        return (self.output.detach_(), self.phi.detach_(),
                self.afimg_or_afparam.detach_()
                if self.afimg_or_afparam is not None else None, None)
Ejemplo n.º 9
0
def build_atlas(images, nr_of_cycles, warped_images, temp_folder, visualize):
    """
    Build an atlas by repeatedly registering every image to the current average.

    The initial average comes from ``compute_average_image``. In each cycle every
    image is registered (``svf_scalar_momentum_map`` model) to the current average,
    and the warped images are averaged into the next estimate. From the second
    cycle on, each registration starts from that image's model parameters of the
    previous cycle. In the last cycle the warped images are written to
    ``warped_images`` as ``atlas_reg_Image<NNNN>.nrrd``.

    :param images: list of image file names
    :param nr_of_cycles: number of register-and-average cycles
    :param warped_images: output directory for the warped images of the last cycle
    :param temp_folder: not used by this function
    :param visualize: show the average image (initially and after each cycle) with matplotlib
    :return: the final average image
    """
    si = SI.RegisterImagePair()
    im_io = FIO.ImageIO()

    # compute first average image
    Iavg, sp = compute_average_image(images)
    Iavg = Iavg.data

    if visualize:
        plt.imshow(AdaptVal(Iavg[0, 0, ...]).detach().cpu().numpy(),
                   cmap='gray')
        plt.title('Initial average based on ' + str(len(images)) + ' images')
        plt.colorbar()
        plt.show()

    # per-image model parameters, carried from one cycle to the next
    mp = []

    # register all images to the average image and while doing so compute a new average image
    for c in range(nr_of_cycles):
        print('Starting cycle ' + str(c + 1) + '/' + str(nr_of_cycles))
        for i, im_name in enumerate(images):
            # 1-based, consistent with the cycle counter above
            print('Registering image ' + str(i + 1) + '/' + str(len(images)))
            Ic, hdrc, spacing, _ = im_io.read_to_nc_format(filename=im_name)

            # start from this image's parameters of the previous cycle
            if c != 0:
                si.set_model_parameters(mp[i])

            # register current image to average image
            si.register_images(Ic,
                               AdaptVal(Iavg).detach().cpu().numpy(),
                               spacing,
                               model_name='svf_scalar_momentum_map',
                               map_low_res_factor=0.5,
                               nr_of_iterations=5,
                               visualize_step=None,
                               similarity_measure_sigma=0.5)
            wi = si.get_warped_image()

            # keep the model parameters for a following cycle
            if c == 0:
                mp.append(si.get_model_parameters())
            elif c != nr_of_cycles - 1:
                mp[i] = si.get_model_parameters()

            if c == nr_of_cycles - 1:  # last cycle: save the warped image
                current_filename = warped_images + '/atlas_reg_Image' + str(
                    i + 1).zfill(4) + '.nrrd'
                print("writing image " + str(i + 1))
                im_io.write(current_filename, wi, hdrc)

            if i == 0:
                # clone so the in-place += below does not write into the warped image's storage
                newAvg = wi.data.clone()
            else:
                newAvg += wi.data

        Iavg = newAvg / len(images)

        if visualize:
            plt.imshow(AdaptVal(Iavg[0, 0, ...]).detach().cpu().numpy(),
                       cmap='gray')
            plt.title('Average ' + str(c + 1) + '/' + str(nr_of_cycles))
            plt.colorbar()
            plt.show()
    return Iavg