コード例 #1
0
def create_dataset(dataset_filenames, visit_ages, subject_ids,
                   template_specifications):
    """
    Create a longitudinal dataset object from xml parameters.

    :param dataset_filenames: Nested sequence indexed as
        ``dataset_filenames[subject][visit][object_id] -> filename``.
    :param visit_ages: Visit times, stored as-is in ``LongitudinalDataset.times``.
    :param subject_ids: Subject identifiers, stored as-is in
        ``LongitudinalDataset.subject_ids``.
    :param template_specifications: Mapping ``object_id -> specification``; every
        specification must provide a ``'deformable_object_type'`` entry.
    :return: The updated ``LongitudinalDataset``.
    :raises RuntimeError: If a template object id is missing for some visit.
    """
    # One reader serves every object: there is no need to rebuild it per file.
    reader = DeformableObjectReader()

    deformable_objects_dataset = []
    for i, subject_filenames in enumerate(dataset_filenames):
        deformable_objects_subject = []
        for j, visit_filenames in enumerate(subject_filenames):
            deformable_objects_visit = DeformableMultiObject()
            for object_id, specification in template_specifications.items():
                if object_id not in visit_filenames:
                    # f-string so non-string object ids do not raise a TypeError
                    # while building the error message itself.
                    raise RuntimeError(
                        f'The template object with id {object_id} is not found '
                        f'for the visit {j} of subject {i}. Check the dataset xml.')
                object_type = specification['deformable_object_type']
                deformable_objects_visit.object_list.append(
                    reader.create_object(visit_filenames[object_id], object_type))
            deformable_objects_visit.update()
            deformable_objects_subject.append(deformable_objects_visit)
        deformable_objects_dataset.append(deformable_objects_subject)

    longitudinal_dataset = LongitudinalDataset()
    longitudinal_dataset.times = visit_ages
    longitudinal_dataset.subject_ids = subject_ids
    longitudinal_dataset.deformable_objects = deformable_objects_dataset
    longitudinal_dataset.update()

    return longitudinal_dataset
コード例 #2
0
    def __init__(self):
        """
        Declare every attribute of the model with its empty/default value.

        Nothing is computed here: the template is an empty ``DeformableMultiObject``,
        the metadata lists are empty, sizes are ``None`` and no parameter is frozen.
        """
        # Must run first: ``self.fixed_effects`` is used below without being created here.
        AbstractStatisticalModel.__init__(self)

        # Template and per-object metadata.
        self.template = DeformableMultiObject()
        self.objects_name, self.objects_name_extension, self.objects_noise_variance = [], [], []

        # Data attachment and deformation engines.
        self.multi_object_attachment = MultiObjectAttachment()
        self.exponential = Exponential()

        # Gradient smoothing options.
        self.use_sobolev_gradient = True
        self.smoothing_kernel_width = None

        # Sizes and geometry, left unset for now.
        self.initial_cp_spacing = self.number_of_subjects = None
        self.number_of_objects = self.number_of_control_points = None
        self.bounding_box = None

        # Fixed effects (dictionary of numpy arrays once set), all unset.
        for effect_name in ('template_data', 'control_points', 'momenta'):
            self.fixed_effects[effect_name] = None

        # Nothing is frozen by default.
        self.freeze_template = self.freeze_control_points = self.freeze_momenta = False
コード例 #3
0
    def __init__(self):
        """
        Declare the model attributes, fixed effects, priors and random effects.

        Fixed effects start as ``None``; the priors and the individual ``momenta``
        random effect are freshly constructed distribution objects.
        """
        # Must run first: ``fixed_effects``, ``priors`` and ``individual_random_effects``
        # are used below without being created here.
        AbstractStatisticalModel.__init__(self)

        # Template and per-object metadata.
        self.template = DeformableMultiObject()
        self.objects_name, self.objects_name_extension, self.objects_noise_dimension = [], [], []

        # The attachment is left unset here; the exponential is built immediately.
        self.multi_object_attachment = None
        self.exponential = Exponential()

        # Gradient smoothing options.
        self.use_sobolev_gradient = True
        self.smoothing_kernel_width = None

        # Sizes and geometry, left unset for now.
        self.initial_cp_spacing = self.number_of_objects = None
        self.number_of_control_points = self.bounding_box = None

        # Fixed effects (dictionary of numpy arrays once set), all unset.
        for effect_name in ('template_data', 'control_points',
                            'covariance_momenta_inverse', 'noise_variance'):
            self.fixed_effects[effect_name] = None

        # Priors (probability distributions).
        self.priors['covariance_momenta'] = InverseWishartDistribution()
        self.priors['noise_variance'] = MultiScalarInverseWishartDistribution()

        # Individual random effects (probability distributions).
        self.individual_random_effects['momenta'] = NormalDistribution()

        # Nothing is frozen by default.
        self.freeze_template = self.freeze_control_points = False
コード例 #4
0
    def __init__(self, kernel_type, kernel_device='CPU', use_cuda=False, data_type='landmark', data_size='small'):
        """
        Build a reproducible random deformation set-up: exponential, template, control points and momenta.

        :param kernel_type: Kernel type forwarded to ``kernel_factory.factory``.
        :param kernel_device: Kernel device forwarded to ``kernel_factory.factory``.
        :param use_cuda: If True, ``self.tensor_scalar_type`` is ``torch.cuda.FloatTensor``,
            otherwise ``torch.FloatTensor``.
        :param data_type: ``'landmark'`` or ``'image'`` (case-insensitive).
        :param data_size: For landmarks, ``'small'`` / ``'large'`` (a surface mesh read from disk) or a
            non-negative integer-like number of random triangles (at most C(100, 3) = 161700 are available).
            For images, the integer-like edge length of a random cubic image.
        :raises RuntimeError: If ``data_type`` is neither ``'landmark'`` nor ``'image'``.
        """
        # Fixed seed: every instance draws the same random data.
        np.random.seed(42)
        kernel_width = 10.

        self.tensor_scalar_type = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

        self.exponential = Exponential(kernel=kernel_factory.factory(kernel_type, kernel_width, kernel_device),
                                       number_of_time_points=11, use_rk2_for_flow=False, use_rk2_for_shoot=False)

        if data_type.lower() == 'landmark':
            if data_size in ('small', 'large'):
                mesh_path = path_to_small_surface_mesh_1 if data_size == 'small' else path_to_large_surface_mesh_1
                surface_mesh = DeformableObjectReader().create_object(mesh_path, 'SurfaceMesh')
                self.control_points = create_regular_grid_of_points(
                    surface_mesh.bounding_box, kernel_width, surface_mesh.dimension)
            else:
                number_of_triangles = int(data_size)
                # Only the first triangles are needed: take them lazily instead of materialising
                # all C(100, 3) = 161700 index triples first.
                connectivity = np.array(list(
                    itertools.islice(itertools.combinations(range(100), 3), number_of_triangles)))
                surface_mesh = SurfaceMesh(3)
                surface_mesh.set_points(np.random.randn(np.max(connectivity) + 1, surface_mesh.dimension))
                surface_mesh.set_connectivity(connectivity)
                surface_mesh.update()
                self.control_points = np.random.randn(number_of_triangles // 10, 3)
            self.template = DeformableMultiObject([surface_mesh])

        elif data_type.lower() == 'image':
            image = Image(3)
            image.set_intensities(np.random.randn(int(data_size), int(data_size), int(data_size)))
            image.set_affine(np.eye(4))
            image.downsampling_factor = 5.
            image.update()
            self.control_points = create_regular_grid_of_points(image.bounding_box, kernel_width, image.dimension)
            self.control_points = remove_useless_control_points(self.control_points, image, kernel_width)
            self.template = DeformableMultiObject([image])

        else:
            raise RuntimeError('Unknown data_type argument. Choose between "landmark" or "image".')

        # Momenta have the same shape as the control points; drawn last so the random stream
        # order matches the data generation above.
        self.momenta = np.random.randn(*self.control_points.shape)
コード例 #5
0
def compute_distance_squared(path_to_mesh_1,
                             path_to_mesh_2,
                             deformable_object_type,
                             attachment_type,
                             kernel_width=None):
    """
    Compare two objects read from disk with a single-object attachment.

    Both files are read as ``deformable_object_type`` (lower-cased) and each is wrapped
    in its own ``DeformableMultiObject``. A ``MultiObjectAttachment`` built from
    ``attachment_type`` and a ``'torch'`` kernel of width ``kernel_width`` then evaluates
    ``compute_distances`` with the first object's points (as torch tensors) against the
    second object.

    :return: The ``compute_distances`` result converted to a numpy array.
    """
    reader = DeformableObjectReader()
    object_type = deformable_object_type.lower()
    first_object, second_object = (reader.create_object(path, object_type)
                                   for path in (path_to_mesh_1, path_to_mesh_2))

    source = DeformableMultiObject([first_object])
    target = DeformableMultiObject([second_object])
    attachment = MultiObjectAttachment(
        [attachment_type], [kernel_factory.factory('torch', kernel_width)])

    source_points = {
        name: torch.from_numpy(points)
        for name, points in source.get_points().items()
    }
    distances = attachment.compute_distances(source_points, source, target)
    return distances.data.cpu().numpy()
コード例 #6
0
ファイル: run_shooting.py プロジェクト: EuroPOND/deformetrica
def run_shooting(xml_parameters):
    """
    Shoot the template along every set of initial momenta and write the resulting flows.

    Control points are read with ``read_2D_array`` from ``xml_parameters.initial_control_points``
    and momenta with ``read_3D_array`` from ``xml_parameters.initial_momenta``. For each momenta
    set ``i`` (first axis), the template flow is written with object names suffixed by ``_i``,
    and the control-point/momenta flow is written under ``"Shooting_" + str(i)``.

    :param xml_parameters: Parsed parameters; reads ``template_specifications``,
        ``initial_control_points``, ``initial_momenta``, ``deformation_kernel_type``,
        ``deformation_kernel_width`` and ``use_rk2``.
    :raises ArgumentError: If the control-points or momenta path is missing.
    """
    print('[ run_shooting function ]')
    print('')

    # Create the template object.
    t_list, t_name, t_name_extension, _, _ = \
        create_template_metadata(xml_parameters.template_specifications)

    print("Object list:", t_list)

    template = DeformableMultiObject()
    template.object_list = t_list
    template.update()

    # Read control points and momenta.
    # NOTE(review): `ArgumentError` is not a builtin; confirm it is imported at file level.
    # argparse.ArgumentError requires (argument, message) and would raise TypeError here.
    if xml_parameters.initial_control_points is None:
        raise ArgumentError('Please specify a path to control points to perform a shooting')
    control_points = read_2D_array(xml_parameters.initial_control_points)

    if xml_parameters.initial_momenta is None:
        raise ArgumentError('Please specify a path to momenta to perform a shooting')
    momenta = read_3D_array(xml_parameters.initial_momenta)

    template_data_torch = Variable(torch.from_numpy(template.get_points()))
    momenta_torch = Variable(torch.from_numpy(momenta))
    control_points_torch = Variable(torch.from_numpy(control_points))

    exp = Exponential()
    exp.set_initial_control_points(control_points_torch)
    exp.set_initial_template_data(template_data_torch)
    exp.number_of_time_points = 10
    exp.kernel = kernel_factory.factory(xml_parameters.deformation_kernel_type,
                                        xml_parameters.deformation_kernel_width)
    exp.set_use_rk2(xml_parameters.use_rk2)

    # One shooting per momenta set.
    for i, momenta_i in enumerate(momenta_torch):
        exp.set_initial_momenta(momenta_i)
        exp.update()
        names = [name + "_" + str(i) for name in t_name]
        exp.write_flow(names, t_name_extension, template)
        exp.write_control_points_and_momenta_flow("Shooting_" + str(i))
コード例 #7
0
class DeterministicAtlas(AbstractStatisticalModel):
    """
    Deterministic atlas object class.

    A template (``DeformableMultiObject``) is deformed towards the first visit of every
    subject by an ``Exponential`` driven by shared control points and per-subject momenta.
    The fixed effects are ``template_data`` (dict of numpy arrays), ``control_points`` and
    ``momenta``; each of them can be frozen independently.
    """

    ####################################################################################################################
    ### Constructor:
    ####################################################################################################################

    def __init__(self):
        AbstractStatisticalModel.__init__(self)

        self.template = DeformableMultiObject()
        self.objects_name = []
        self.objects_name_extension = []
        self.objects_noise_variance = []

        self.multi_object_attachment = MultiObjectAttachment()
        self.exponential = Exponential()

        self.use_sobolev_gradient = True
        self.smoothing_kernel_width = None

        self.initial_cp_spacing = None
        self.number_of_subjects = None
        self.number_of_objects = None
        self.number_of_control_points = None
        self.bounding_box = None

        # Dictionary of numpy arrays.
        self.fixed_effects['template_data'] = None
        self.fixed_effects['control_points'] = None
        self.fixed_effects['momenta'] = None

        self.freeze_template = False
        self.freeze_control_points = False
        self.freeze_momenta = False

    ####################################################################################################################
    ### Encapsulation methods:
    ####################################################################################################################

    # Template data ----------------------------------------------------------------------------------------------------
    def get_template_data(self):
        return self.fixed_effects['template_data']

    def set_template_data(self, td):
        """Store the template data fixed effect and push it into the template object."""
        self.fixed_effects['template_data'] = td
        self.template.set_data(td)

    # Control points ---------------------------------------------------------------------------------------------------
    def get_control_points(self):
        return self.fixed_effects['control_points']

    def set_control_points(self, cp):
        """Store the control points fixed effect and refresh ``number_of_control_points``."""
        self.fixed_effects['control_points'] = cp
        self.number_of_control_points = len(cp)

    # Momenta ----------------------------------------------------------------------------------------------------------
    def get_momenta(self):
        return self.fixed_effects['momenta']

    def set_momenta(self, mom):
        self.fixed_effects['momenta'] = mom

    # Full fixed effects -----------------------------------------------------------------------------------------------
    def get_fixed_effects(self):
        """
        Return the non-frozen fixed effects as a flat dictionary.

        Template data entries are flattened into the output under their own keys,
        next to ``'control_points'`` and ``'momenta'``.
        """
        out = {}
        if not self.freeze_template:
            for key, value in self.fixed_effects['template_data'].items():
                out[key] = value
        if not self.freeze_control_points:
            out['control_points'] = self.fixed_effects['control_points']
        if not self.freeze_momenta:
            out['momenta'] = self.fixed_effects['momenta']
        return out

    def set_fixed_effects(self, fixed_effects):
        """Inverse of ``get_fixed_effects``: update every non-frozen fixed effect from a flat dictionary."""
        if not self.freeze_template:
            template_data = {
                key: fixed_effects[key]
                for key in self.fixed_effects['template_data'].keys()
            }
            self.set_template_data(template_data)
        if not self.freeze_control_points:
            self.set_control_points(fixed_effects['control_points'])
        if not self.freeze_momenta:
            self.set_momenta(fixed_effects['momenta'])

    ####################################################################################################################
    ### Public methods:
    ####################################################################################################################

    def update(self):
        """
        Final initialization steps.

        Updates the template, then initializes the control points (or, when they are
        user-provided, enlarges the bounding box to contain them) and the momenta if unset.
        """

        self.template.update()
        self.number_of_objects = len(self.template.object_list)
        self.bounding_box = self.template.bounding_box

        self.set_template_data(self.template.get_data())
        if self.fixed_effects['control_points'] is None:
            self._initialize_control_points()
        else:
            self._initialize_bounding_box()
        if self.fixed_effects['momenta'] is None:
            self._initialize_momenta()

    # Compute the functional. Numpy input/outputs.
    def compute_log_likelihood(self,
                               dataset,
                               population_RER,
                               individual_RER,
                               mode='complete',
                               with_grad=False):
        """
        Compute the log-likelihood of the dataset given the current fixed effects.

        ``population_RER``, ``individual_RER`` and ``mode`` are accepted for interface
        compatibility but are not used by this model.

        :param dataset: LongitudinalDataset instance; only the first visit of each subject is used.
        :param population_RER: Dictionary of population random effects realizations (unused).
        :param individual_RER: Dictionary of individual random effects realizations (unused).
        :param mode: Which log-likelihood to compute, between 'complete', 'model' and 'class2' (unused).
        :param with_grad: If True, also return the gradient with respect to the non-frozen fixed effects.
        :return: ``(attachment, regularity)`` converted to numpy, or ``(attachment, regularity, gradient)``
            when ``with_grad`` is True, ``gradient`` being a dict of numpy arrays.
        """

        # The multi-threaded path is deliberately disabled by the leading `False`.
        if False and Settings().number_of_threads > 1:
            targets = [target[0] for target in dataset.deformable_objects]
            args = [
                (i, Settings().serialize(), self.template,
                 self.fixed_effects['template_data'],
                 self.fixed_effects['control_points'],
                 self.fixed_effects['momenta'][i], self.freeze_template,
                 self.freeze_control_points, self.freeze_momenta, targets[i],
                 self.multi_object_attachment, self.objects_noise_variance,
                 self.exponential.light_copy(), with_grad,
                 self.use_sobolev_gradient, self.smoothing_kernel_width)
                for i in range(len(targets))
            ]

            # Perform parallelized computations.
            with ThreadPoolExecutor(
                    max_workers=Settings().number_of_threads) as pool:
                results = pool.map(_subject_attachment_and_regularity, args)

            # Sum and return.
            if with_grad:
                attachment = 0.0
                regularity = 0.0

                gradient = {}
                if not self.freeze_template:
                    for key, value in self.fixed_effects[
                            'template_data'].items():
                        gradient[key] = np.zeros(value.shape)
                if not self.freeze_control_points:
                    gradient['control_points'] = np.zeros(
                        self.fixed_effects['control_points'].shape)
                if not self.freeze_momenta:
                    gradient['momenta'] = np.zeros(
                        self.fixed_effects['momenta'].shape)

                for result in results:
                    i, attachment_i, regularity_i, gradient_i = result
                    attachment += attachment_i
                    regularity += regularity_i
                    for key, value in gradient_i.items():
                        if key == 'momenta':
                            gradient[key][i] = value
                        else:
                            gradient[key] += value
                return attachment, regularity, gradient
            else:
                attachment = 0.0
                regularity = 0.0
                for result in results:
                    i, attachment_i, regularity_i = result
                    attachment += attachment_i
                    regularity += regularity_i
                # Return only once every subject has been summed.
                return attachment, regularity

        else:
            template_data, template_points, control_points, momenta = self._fixed_effects_to_torch_tensors(
                with_grad)
            return self._compute_attachment_and_regularity(
                dataset, template_data, template_points, control_points,
                momenta, with_grad)

    def initialize_template_attributes(self, template_specifications):
        """
        Set the template objects and their metadata from ``template_specifications``.

        Sets ``template.object_list``, ``objects_name``, ``objects_name_extension``,
        ``objects_noise_variance`` and ``multi_object_attachment``, then updates the template.
        """

        t_list, t_name, t_name_extension, t_noise_variance, t_multi_object_attachment = \
            create_template_metadata(template_specifications)

        self.template.object_list = t_list
        self.objects_name = t_name
        self.objects_name_extension = t_name_extension
        self.objects_noise_variance = t_noise_variance
        self.multi_object_attachment = t_multi_object_attachment
        self.template.update()

    ####################################################################################################################
    ### Private methods:
    ####################################################################################################################

    def _compute_attachment_and_regularity(self,
                                           dataset,
                                           template_data,
                                           template_points,
                                           control_points,
                                           momenta,
                                           with_grad=False):
        """
        Core part of compute_log_likelihood. Torch input, numpy output.
        Single-thread version.

        Each subject's momenta shoot the template towards that subject's first visit.
        ``regularity`` accumulates minus the squared norm of each deformation, and
        ``attachment`` accumulates minus the weighted distance to each target.
        """
        # Initialize.
        targets = [target[0] for target in dataset.deformable_objects]

        regularity = 0.
        attachment = 0.

        # Deform.
        self.exponential.set_initial_template_points(template_points)
        self.exponential.set_initial_control_points(control_points)

        for i, target in enumerate(targets):
            self.exponential.set_initial_momenta(momenta[i])
            self.exponential.update()
            deformed_points = self.exponential.get_template_points()
            deformed_data = self.template.get_deformed_data(
                deformed_points, template_data)
            regularity -= self.exponential.get_norm_squared()
            attachment -= self.multi_object_attachment.compute_weighted_distance(
                deformed_data, self.template, target,
                self.objects_noise_variance)

        # Compute gradient.
        if with_grad:
            # Back-propagate the full objective. Back-propagating the attachment alone would
            # silently drop the regularity contribution from every returned gradient.
            total = attachment + regularity
            total.backward()

            gradient = {}
            if not self.freeze_template:
                if 'landmark_points' in template_data.keys():
                    if self.use_sobolev_gradient:
                        gradient['landmark_points'] = compute_sobolev_gradient(
                            template_points['landmark_points'].grad.detach(),
                            self.smoothing_kernel_width,
                            self.template).cpu().numpy()
                    else:
                        gradient['landmark_points'] = template_points[
                            'landmark_points'].grad.detach().cpu().numpy()
                if 'image_intensities' in template_data.keys():
                    gradient['image_intensities'] = template_data[
                        'image_intensities'].grad.detach().cpu().numpy()
            if not self.freeze_control_points:
                gradient['control_points'] = control_points.grad.detach().cpu(
                ).numpy()
            if not self.freeze_momenta:
                gradient['momenta'] = momenta.grad.detach().cpu().numpy()

            return attachment.detach().cpu().numpy(), regularity.detach().cpu(
            ).numpy(), gradient

        else:
            return attachment.detach().cpu().numpy(), regularity.detach().cpu(
            ).numpy()

    def _initialize_control_points(self):
        """
        Initialize the control points fixed effect.

        Outside dense mode, a regular grid with spacing ``initial_cp_spacing`` covers the
        bounding box and is pruned by ``remove_useless_control_points`` for the first image
        object, if any. In dense mode, the template points are used directly.
        """
        if not Settings().dense_mode:
            control_points = create_regular_grid_of_points(
                self.bounding_box, self.initial_cp_spacing)
            for elt in self.template.object_list:
                if elt.type.lower() == 'image':
                    control_points = remove_useless_control_points(
                        control_points, elt,
                        self.exponential.get_kernel_width())
                    break
        else:
            # NOTE(review): _fixed_effects_to_torch_tensors iterates template.get_points()
            # as a dict; if so, the `.shape` access below fails here. Confirm.
            control_points = self.template.get_points()

        # Filtering of too-close control points was prototyped here and is disabled.

        self.set_control_points(control_points)
        self.number_of_control_points = control_points.shape[0]
        logger.info('Set of %s control points defined.', self.number_of_control_points)

    def _initialize_momenta(self):
        """
        Initialize the momenta fixed effect to zero, with shape
        (number_of_subjects, number_of_control_points, dimension).
        """

        assert (self.number_of_subjects > 0)
        momenta = np.zeros(
            (self.number_of_subjects, self.number_of_control_points,
             Settings().dimension))
        self.set_momenta(momenta)
        logger.info('Momenta initialized to zero, for %s subjects.', self.number_of_subjects)

    def _initialize_bounding_box(self):
        """
        Enlarge the bounding box so that it tightly encloses all template objects and the
        atlas control points. Relevant when the control points are given by the user.
        """

        assert (self.number_of_control_points > 0)

        dimension = Settings().dimension
        control_points = self.get_control_points()

        for k in range(self.number_of_control_points):
            for d in range(dimension):
                if control_points[k, d] < self.bounding_box[d, 0]:
                    self.bounding_box[d, 0] = control_points[k, d]
                elif control_points[k, d] > self.bounding_box[d, 1]:
                    self.bounding_box[d, 1] = control_points[k, d]

    ####################################################################################################################
    ### Private utility methods:
    ####################################################################################################################

    def _fixed_effects_to_torch_tensors(self, with_grad):
        """
        Convert the fixed effects into torch Variables of ``Settings().tensor_scalar_type``.

        ``requires_grad`` is set only for non-frozen effects when ``with_grad`` is True.
        Control points also require grad when the exponential's kernel type is 'keops'.

        :return: ``(template_data, template_points, control_points, momenta)``; the first two
            are dictionaries of Variables.
        """
        scalar_type = Settings().tensor_scalar_type

        # Template data.
        template_data = {
            key: Variable(
                torch.from_numpy(value).type(scalar_type),
                requires_grad=(not self.freeze_template and with_grad))
            for key, value in self.fixed_effects['template_data'].items()
        }

        # Template points.
        template_points = {
            key: Variable(
                torch.from_numpy(value).type(scalar_type),
                requires_grad=(not self.freeze_template and with_grad))
            for key, value in self.template.get_points().items()
        }

        # Control points.
        if Settings().dense_mode:
            assert 'image_intensities' not in template_data.keys() and 'image_points' not in template_points.keys(), \
                'Dense mode not available with image data.'
            # NOTE(review): this hands the template_data *dict* to the exponential, while the
            # non-dense branch produces a single tensor. Confirm this is intended (perhaps
            # template_points['landmark_points'] was meant).
            control_points = template_data
        else:
            control_points = Variable(
                torch.from_numpy(self.fixed_effects['control_points']).type(scalar_type),
                requires_grad=((not self.freeze_control_points and with_grad)
                               or self.exponential.get_kernel_type()
                               == 'keops'))
        # Momenta.
        momenta = Variable(
            torch.from_numpy(self.fixed_effects['momenta']).type(scalar_type),
            requires_grad=(not self.freeze_momenta and with_grad))

        return template_data, template_points, control_points, momenta

    ####################################################################################################################
    ### Writing methods:
    ####################################################################################################################

    def write(self,
              dataset,
              population_RER,
              individual_RER,
              write_residuals=True):
        """
        Write the reconstructions, optionally the residuals, and the estimated parameters.

        ``population_RER`` is not used.
        """

        # Write the model predictions, and compute the residuals at the same time.
        residuals = self._write_model_predictions(
            dataset, individual_RER, compute_residuals=write_residuals)

        # Write residuals.
        if write_residuals:
            residuals_list = [[
                residuals_i_k.data.cpu().numpy()
                for residuals_i_k in residuals_i
            ] for residuals_i in residuals]
            write_2D_list(residuals_list,
                          self.name + "__EstimatedParameters__Residuals.txt")

        # Write the model parameters.
        self._write_model_parameters()

    def _write_model_predictions(self,
                                 dataset,
                                 individual_RER,
                                 compute_residuals=True):
        """
        Deform the template towards each subject, write the reconstructions and return the
        per-subject distances (empty list when ``compute_residuals`` is False).
        """

        # Initialize.
        template_data, template_points, control_points, momenta = self._fixed_effects_to_torch_tensors(
            False)

        # Deform, write reconstructions and compute residuals.
        self.exponential.set_initial_template_points(template_points)
        self.exponential.set_initial_control_points(control_points)

        residuals = []  # List of torch 1D tensors. Individuals, objects.
        for i, subject_id in enumerate(dataset.subject_ids):
            self.exponential.set_initial_momenta(momenta[i])
            self.exponential.update()

            deformed_points = self.exponential.get_template_points()
            deformed_data = self.template.get_deformed_data(
                deformed_points, template_data)

            if compute_residuals:
                residuals.append(
                    self.multi_object_attachment.compute_distances(
                        deformed_data, self.template,
                        dataset.deformable_objects[i][0]))

            names = [
                self.name + '__Reconstruction__' + object_name + '__subject_' + str(subject_id) + object_extension
                for object_name, object_extension in zip(self.objects_name, self.objects_name_extension)
            ]
            self.template.write(
                names, {
                    key: value.data.cpu().numpy()
                    for key, value in deformed_data.items()
                })

        return residuals

    def _write_model_parameters(self):
        """Write the template objects, the control points and the momenta to disk."""

        # Template.
        template_names = [
            self.name + "__EstimatedParameters__Template_" + object_name + object_extension
            for object_name, object_extension in zip(self.objects_name, self.objects_name_extension)
        ]
        self.template.write(template_names)

        # Control points.
        write_2D_array(self.get_control_points(),
                       self.name + "__EstimatedParameters__ControlPoints.txt")

        # Momenta.
        write_3D_array(self.get_momenta(),
                       self.name + "__EstimatedParameters__Momenta.txt")
    def __init__(self,
                 template_specifications,
                 dimension=default.dimension,
                 tensor_scalar_type=default.tensor_scalar_type,
                 tensor_integer_type=default.tensor_integer_type,
                 number_of_threads=default.number_of_threads,
                 deformation_kernel_type=default.deformation_kernel_type,
                 deformation_kernel_width=default.deformation_kernel_width,
                 deformation_kernel_device=default.deformation_kernel_device,
                 shoot_kernel_type=default.shoot_kernel_type,
                 number_of_time_points=default.number_of_time_points,
                 freeze_template=default.freeze_template,
                 use_sobolev_gradient=default.use_sobolev_gradient,
                 smoothing_kernel_width=default.smoothing_kernel_width,
                 estimate_initial_velocity=default.estimate_initial_velocity,
                 initial_velocity_weight=default.initial_velocity_weight,
                 initial_control_points=default.initial_control_points,
                 freeze_control_points=default.freeze_control_points,
                 initial_cp_spacing=default.initial_cp_spacing,
                 initial_impulse_t=None,
                 initial_velocity=None,
                 **kwargs):
        """
        Build the model from the template specifications.

        In order, this sets:
        - the acceleration path (deformation kernel);
        - the template;
        - zero-initialised gompertz images 'A', 'B' and 'C' shaped like the template's
          'image_intensities';
        - an optional Sobolev smoothing kernel;
        - the control points;
        - the impulse;
        - optionally the initial velocity.

        Extra keyword arguments (``kwargs``) are accepted and ignored.
        """
        AbstractStatisticalModel.__init__(self, name='AccelerationRegression')

        # Global-like attributes.
        self.dimension, self.tensor_scalar_type = dimension, tensor_scalar_type
        self.tensor_integer_type, self.number_of_threads = tensor_integer_type, number_of_threads

        # Declare model structure.
        for effect_name in ('template_data', 'control_points'):
            self.fixed_effects[effect_name] = None

        self.freeze_template, self.freeze_control_points = freeze_template, freeze_control_points

        # Deformation.
        deformation_kernel = kernel_factory.factory(deformation_kernel_type,
                                                    deformation_kernel_width,
                                                    device=deformation_kernel_device)
        self.acceleration_path = AccelerationPath(kernel=deformation_kernel,
                                                  shoot_kernel_type=shoot_kernel_type,
                                                  number_of_time_points=number_of_time_points)

        # Template.
        (object_list,
         self.objects_name,
         self.objects_name_extension,
         self.objects_noise_variance,
         self.multi_object_attachment) = create_template_metadata(template_specifications, self.dimension)

        self.template = DeformableMultiObject(object_list)
        self.template.update()

        # Gompertz images A, B and C: zeros shaped like the template intensities.
        intensities_shape = self.template.get_data()['image_intensities'].shape
        for gompertz_key in ('A', 'B', 'C'):
            self.fixed_effects[gompertz_key] = np.zeros(intensities_shape)

        self.number_of_objects = len(self.template.object_list)

        # Optional Sobolev smoothing kernel (only created when requested).
        self.use_sobolev_gradient = use_sobolev_gradient
        self.smoothing_kernel_width = smoothing_kernel_width
        if self.use_sobolev_gradient:
            self.sobolev_kernel = kernel_factory.factory(deformation_kernel_type,
                                                         smoothing_kernel_width,
                                                         device=deformation_kernel_device)

        # Template data and control points.
        self.fixed_effects['template_data'] = self.template.get_data()
        self.fixed_effects['control_points'] = initialize_control_points(
            initial_control_points, self.template, initial_cp_spacing,
            deformation_kernel_width, self.dimension, False)

        self.estimate_initial_velocity = estimate_initial_velocity
        self.initial_velocity_weight = initial_velocity_weight

        self.number_of_control_points = len(self.fixed_effects['control_points'])
        self.number_of_time_points = number_of_time_points

        # Impulse, and the initial velocity only when it is estimated.
        self.fixed_effects['impulse_t'] = initialize_impulse(
            initial_impulse_t, self.number_of_time_points,
            self.number_of_control_points, self.dimension)
        if self.estimate_initial_velocity:
            self.fixed_effects['initial_velocity'] = initialize_initial_velocity(
                initial_velocity, self.number_of_control_points, self.dimension)
class AccelerationGompertzRegression(AbstractStatisticalModel):
    """
    Acceleration regression object class with gompertz intensity model change.

    The template points are transported along an acceleration path (driven by
    the time-dependent impulse and an optional initial velocity), while the
    image intensities follow the voxel-wise Gompertz curve
    ``A * exp(-B * exp(-C * t))``. The parameter images A, B and C are
    estimated as fixed effects alongside the template, control points and
    impulse.
    """

    ####################################################################################################################
    ### Constructor:
    ####################################################################################################################

    def __init__(self,
                 template_specifications,
                 dimension=default.dimension,
                 tensor_scalar_type=default.tensor_scalar_type,
                 tensor_integer_type=default.tensor_integer_type,
                 number_of_threads=default.number_of_threads,
                 deformation_kernel_type=default.deformation_kernel_type,
                 deformation_kernel_width=default.deformation_kernel_width,
                 deformation_kernel_device=default.deformation_kernel_device,
                 shoot_kernel_type=default.shoot_kernel_type,
                 number_of_time_points=default.number_of_time_points,
                 freeze_template=default.freeze_template,
                 use_sobolev_gradient=default.use_sobolev_gradient,
                 smoothing_kernel_width=default.smoothing_kernel_width,
                 estimate_initial_velocity=default.estimate_initial_velocity,
                 initial_velocity_weight=default.initial_velocity_weight,
                 initial_control_points=default.initial_control_points,
                 freeze_control_points=default.freeze_control_points,
                 initial_cp_spacing=default.initial_cp_spacing,
                 initial_impulse_t=None,
                 initial_velocity=None,
                 **kwargs):
        """
        Build the template, the acceleration path and the initial fixed effects.

        A, B and C start as zero images with the shape of the template's
        'image_intensities'. The template must therefore contain an image.
        Extra ``**kwargs`` are accepted and ignored.
        """
        # NOTE(review): the name passed to the base class is 'AccelerationRegression'
        # (shared with the non-Gompertz model). It presumably becomes self.name,
        # which prefixes every output file. It is kept as-is so existing output
        # names do not change.
        AbstractStatisticalModel.__init__(self, name='AccelerationRegression')

        # Global-like attributes.
        self.dimension = dimension
        self.tensor_scalar_type = tensor_scalar_type
        self.tensor_integer_type = tensor_integer_type
        self.number_of_threads = number_of_threads

        # Declare model structure.
        self.fixed_effects['template_data'] = None
        self.fixed_effects['control_points'] = None

        self.freeze_template = freeze_template
        self.freeze_control_points = freeze_control_points

        # Deformation.
        self.acceleration_path = AccelerationPath(
            kernel=kernel_factory.factory(deformation_kernel_type,
                                          deformation_kernel_width,
                                          device=deformation_kernel_device),
            shoot_kernel_type=shoot_kernel_type,
            number_of_time_points=number_of_time_points)

        # Template.
        (object_list, self.objects_name, self.objects_name_extension,
         self.objects_noise_variance,
         self.multi_object_attachment) = create_template_metadata(
             template_specifications, self.dimension)

        self.template = DeformableMultiObject(object_list)
        self.template.update()

        template_data = self.template.get_data()

        # Gompertz parameter images A, B and C. These are three independent
        # arrays, because B and C are clamped in place later on.
        intensities_shape = template_data['image_intensities'].shape
        self.fixed_effects['A'] = np.zeros(intensities_shape)
        self.fixed_effects['B'] = np.zeros(intensities_shape)
        self.fixed_effects['C'] = np.zeros(intensities_shape)

        self.number_of_objects = len(self.template.object_list)

        self.use_sobolev_gradient = use_sobolev_gradient
        self.smoothing_kernel_width = smoothing_kernel_width
        if self.use_sobolev_gradient:
            self.sobolev_kernel = kernel_factory.factory(
                deformation_kernel_type,
                smoothing_kernel_width,
                device=deformation_kernel_device)

        # Template data.
        self.fixed_effects['template_data'] = self.template.get_data()

        # Control points.
        self.fixed_effects['control_points'] = initialize_control_points(
            initial_control_points, self.template, initial_cp_spacing,
            deformation_kernel_width, self.dimension, False)

        self.estimate_initial_velocity = estimate_initial_velocity
        self.initial_velocity_weight = initial_velocity_weight

        self.number_of_control_points = len(
            self.fixed_effects['control_points'])
        self.number_of_time_points = number_of_time_points

        # Impulse.
        self.fixed_effects['impulse_t'] = initialize_impulse(
            initial_impulse_t, self.number_of_time_points,
            self.number_of_control_points, self.dimension)
        if self.estimate_initial_velocity:
            self.fixed_effects[
                'initial_velocity'] = initialize_initial_velocity(
                    initial_velocity, self.number_of_control_points,
                    self.dimension)

    def initialize_noise_variance(self, dataset):
        """
        Automatically choose every noise variance that is still negative (unset).

        The current model deforms the raw template data along the acceleration
        path to each observation of the first subject, and accumulates the
        per-object attachment distances. Each unset variance becomes 1/100th of
        that object's residual, averaged over the observation times.
        """
        if np.min(self.objects_noise_variance) < 0:
            # _fixed_effects_to_torch_tensors returns 8 values. The Gompertz
            # images are not needed here, but they must still be unpacked
            # (unpacking only 5 values raised ValueError).
            (template_data, template_points, control_points, impulse_t,
             initial_velocity, _, _, _) = self._fixed_effects_to_torch_tensors(False)
            target_times = dataset.times[0]
            target_objects = dataset.deformable_objects[0]

            self.acceleration_path.set_tmin(min(target_times))
            self.acceleration_path.set_tmax(max(target_times))
            self.acceleration_path.set_template_points_tmin(template_points)
            self.acceleration_path.set_control_points_tmin(control_points)
            self.acceleration_path.set_impulse_t(impulse_t)
            self.acceleration_path.set_initial_velocity(initial_velocity)
            self.acceleration_path.update()

            residuals = np.zeros((self.number_of_objects, ))
            for time, target in zip(target_times, target_objects):
                deformed_points = self.acceleration_path.get_template_points(
                    time)
                deformed_data = self.template.get_deformed_data(
                    deformed_points, template_data)
                # detach().cpu() so this also works when tensors live on a GPU.
                residuals += self.multi_object_attachment.compute_distances(
                    deformed_data, self.template,
                    target).detach().cpu().numpy()

            # Initialize the noise variance hyper-parameter as a 1/100th of the initial residual.
            for k, obj in enumerate(self.objects_name):
                if self.objects_noise_variance[k] < 0:
                    nv = 0.01 * residuals[k] / float(len(target_times))
                    self.objects_noise_variance[k] = nv
                    print('>> Automatically chosen noise std: %.4f [ %s ]' %
                          (math.sqrt(nv), obj))

    def set_asymptote_image(self, A):
        """Store ``A`` as the ``self.A`` attribute (distinct from fixed_effects['A'])."""
        print("SETTING ASYMPTOTE IMAGE")
        self.A = A

    ####################################################################################################################
    ### Encapsulation methods:
    ####################################################################################################################

    # Template data ----------------------------------------------------------------------------------------------------
    def get_template_data(self):
        return self.fixed_effects['template_data']

    def set_template_data(self, td):
        """Update the template data both in the fixed effects and on the template object."""
        self.fixed_effects['template_data'] = td
        self.template.set_data(td)

    # Control points ---------------------------------------------------------------------------------------------------
    def get_control_points(self):
        return self.fixed_effects['control_points']

    def set_control_points(self, cp):
        # NOTE: self.number_of_control_points is intentionally not updated here.
        self.fixed_effects['control_points'] = cp

    # Impulse ----------------------------------------------------------------------------------------------------------
    def get_impulse_t(self):
        return self.fixed_effects['impulse_t']

    def set_impulse_t(self, impulse_t):
        self.fixed_effects['impulse_t'] = impulse_t

    # Gompertz parameter images ----------------------------------------------------------------------------------------
    def get_A(self):
        return self.fixed_effects['A']

    def set_A(self, A):
        self.fixed_effects['A'] = A

    def get_B(self):
        return self.fixed_effects['B']

    def set_B(self, B):
        self.fixed_effects['B'] = B

    def get_C(self):
        return self.fixed_effects['C']

    def set_C(self, C):
        self.fixed_effects['C'] = C

    # Initial velocity -------------------------------------------------------------------------------------------------
    def get_initial_velocity(self):
        """Return the estimated initial velocity, or zeros when it is not estimated."""
        if self.estimate_initial_velocity:
            return self.fixed_effects['initial_velocity']
        return np.zeros((self.number_of_control_points, self.dimension))

    def set_initial_velocity(self, initial_velocity):
        self.fixed_effects['initial_velocity'] = initial_velocity

    # Full fixed effects -----------------------------------------------------------------------------------------------
    def get_fixed_effects(self):
        """
        Return a flat dict of the fixed effects that are being optimized.

        Template data entries are spread at the top level (unless the template
        is frozen). Control points are omitted when frozen, and the initial
        velocity is included only when it is estimated.
        """
        out = {}
        if not self.freeze_template:
            for key, value in self.fixed_effects['template_data'].items():
                out[key] = value
        if not self.freeze_control_points:
            out['control_points'] = self.fixed_effects['control_points']
        out['impulse_t'] = self.fixed_effects['impulse_t']
        if self.estimate_initial_velocity:
            out['initial_velocity'] = self.fixed_effects['initial_velocity']
        out['A'] = self.fixed_effects['A']
        out['B'] = self.fixed_effects['B']
        out['C'] = self.fixed_effects['C']
        return out

    def set_fixed_effects(self, fixed_effects):
        """Inverse of get_fixed_effects: write the flat dict back into the model."""
        if not self.freeze_template:
            template_data = {
                key: fixed_effects[key]
                for key in self.fixed_effects['template_data'].keys()
            }
            self.set_template_data(template_data)
        if not self.freeze_control_points:
            self.set_control_points(fixed_effects['control_points'])
        self.set_impulse_t(fixed_effects['impulse_t'])
        if self.estimate_initial_velocity:
            self.set_initial_velocity(fixed_effects['initial_velocity'])
        self.set_A(fixed_effects['A'])
        self.set_B(fixed_effects['B'])
        self.set_C(fixed_effects['C'])

    ####################################################################################################################
    ### Public methods:
    ####################################################################################################################

    # Compute the functional. Numpy input/outputs.
    def compute_log_likelihood(self,
                               dataset,
                               population_RER,
                               individual_RER,
                               mode='complete',
                               with_grad=False,
                               cur_iter=None):
        """
        Compute the log-likelihood of the dataset, given the current fixed effects.

        :param dataset: LongitudinalDataset instance (only its first subject is used).
        :param population_RER: Dictionary of population random effects realizations (unused).
        :param individual_RER: Dictionary of individual random effects realizations (unused).
        :param mode: unused.
        :param with_grad: Flag that indicates whether the gradient should be returned as well.
        :param cur_iter: unused.
        :return: ``(attachment, regularity)``, or ``(attachment, regularity, gradient)``
                 when ``with_grad`` is set. ``gradient`` maps fixed-effect names to numpy arrays.
        """
        # Initialize: conversion from numpy to torch -------------------------------------------------------------------
        template_data, template_points, control_points, impulse_t, initial_velocity, A, B, C = self._fixed_effects_to_torch_tensors(
            with_grad)

        # Deform -------------------------------------------------------------------------------------------------------
        deformation_attachment, intensity_attachment, regularity, velocity_regularity, total_variation = self._compute_attachment_and_regularity(
            dataset, template_data, template_points, control_points, impulse_t,
            initial_velocity, A, B, C)

        attachment_np = deformation_attachment.detach().cpu().numpy() + intensity_attachment.detach().cpu().numpy()
        regularity_np = regularity.detach().cpu().numpy() + self.initial_velocity_weight * velocity_regularity.detach().cpu().numpy()

        if not with_grad:
            return attachment_np, regularity_np

        # Compute gradient ---------------------------------------------------------------------------------------------
        # NOTE: the total-variation term drives the gradient but is NOT part of
        # the returned regularity value.
        total = self.initial_velocity_weight * velocity_regularity + regularity + total_variation + intensity_attachment + deformation_attachment
        total.backward()

        gradient = {}
        # Template data.
        if not self.freeze_template:
            if 'landmark_points' in template_data.keys():
                gradient['landmark_points'] = template_points[
                    'landmark_points'].grad
            if 'image_intensities' in template_data.keys():
                gradient['image_intensities'] = template_data[
                    'image_intensities'].grad

            if self.use_sobolev_gradient and 'landmark_points' in gradient.keys():
                # Smooth the landmark gradient with the Sobolev kernel.
                gradient['landmark_points'] = self.sobolev_kernel.convolve(
                    template_data['landmark_points'].detach(),
                    template_data['landmark_points'].detach(),
                    gradient['landmark_points'].detach())

        # Control points.
        if not self.freeze_control_points:
            gradient['control_points'] = control_points.grad

        # Initial velocity.
        if self.estimate_initial_velocity:
            gradient['initial_velocity'] = initial_velocity.grad

        # Impulse and Gompertz parameter images.
        gradient['impulse_t'] = impulse_t.grad
        gradient['A'] = A.grad
        gradient['B'] = B.grad
        gradient['C'] = C.grad

        # Convert the gradient back to numpy.
        gradient = {
            key: value.data.cpu().numpy()
            for key, value in gradient.items()
        }

        return attachment_np, regularity_np, gradient

    ####################################################################################################################
    ### Private methods:
    ####################################################################################################################

    @staticmethod
    def _gompertz_intensities(A, B, C, t):
        """Evaluate the voxel-wise Gompertz intensity model ``A * exp(-B * exp(-C * t))``."""
        return A * torch.exp(-B * torch.exp(-C * t))

    def _compute_attachment_and_regularity(self, dataset, template_data,
                                           template_points, control_points,
                                           impulse_t, initial_velocity, A, B,
                                           C):
        """
        Core part of the ComputeLogLikelihood methods. Fully torch.

        Only the first subject of ``dataset`` is used (cross-sectional data).

        :return: ``(deformation_attachment, intensity_attachment, regularity,
                 velocity_regularity, total_variation)``. Every term is negated
                 (log-likelihood sign convention). All terms except
                 ``velocity_regularity`` are already multiplied by their weights.
        """
        # Initialize: cross-sectional dataset --------------------------------------------------------------------------
        target_times = dataset.times[0]
        target_objects = dataset.deformable_objects[0]
        tmin = min(target_times)

        # Deform -------------------------------------------------------------------------------------------------------
        self.acceleration_path.set_tmin(tmin)
        self.acceleration_path.set_tmax(max(target_times))
        self.acceleration_path.set_template_points_tmin(template_points)
        self.acceleration_path.set_control_points_tmin(control_points)
        self.acceleration_path.set_impulse_t(impulse_t)
        self.acceleration_path.set_initial_velocity(initial_velocity)
        self.acceleration_path.update()

        # The deformation-only attachment uses unit noise variance for every object.
        deformation_noise_variance = np.ones(len(self.objects_noise_variance))

        # Hard-coded relative weights of the objective terms.
        intensity_weight = 100
        deformation_weight = 1
        regularity_weight = 10000

        # Gompertz intensities at the earliest time. These are loop-invariant:
        # they isolate the pure shape deformation and feed the TV term.
        baseline_intensities = self._gompertz_intensities(A, B, C, tmin)

        total_variation = 0.
        deformation_attachment = 0.
        intensity_attachment = 0.
        for time, obj in zip(target_times, target_objects):
            deformed_points = self.acceleration_path.get_template_points(time)

            # Shape deformation combined with the intensity change at `time`.
            deformed_data_with_intensity = self.template.get_deformed_data(
                deformed_points,
                {'image_intensities': self._gompertz_intensities(A, B, C, time)})
            intensity_attachment -= self.multi_object_attachment.compute_weighted_distance(
                deformed_data_with_intensity, self.template, obj,
                self.objects_noise_variance)

            # Shape deformation only: intensities frozen at tmin.
            deformed_data_no_intensity = self.template.get_deformed_data(
                deformed_points, {'image_intensities': baseline_intensities})
            deformation_attachment -= self.multi_object_attachment.compute_weighted_distance(
                deformed_data_no_intensity, self.template, obj,
                deformation_noise_variance)

        if self.dimension == 2:
            total_var_weight = 0.1
            # Anisotropic (vertical-only) total variation of the baseline image:
            # the L1 norm of the differences between adjacent rows.
            height, _ = baseline_intensities.size()
            dy = torch.abs(baseline_intensities[1:, :] -
                           baseline_intensities[:-1, :])
            error = torch.norm(dy, 1)
            total_variation = -(error / height) * total_var_weight

        regularity = -self.acceleration_path.get_norm_squared(
        ) * regularity_weight

        velocity_regularity = -self.acceleration_path.get_velocity_norm()

        deformation_attachment = deformation_attachment * deformation_weight
        intensity_attachment = intensity_attachment * intensity_weight

        return deformation_attachment, intensity_attachment, regularity, velocity_regularity, total_variation

    ####################################################################################################################
    ### Private utility methods:
    ####################################################################################################################

    def _fixed_effects_to_torch_tensors(self, with_grad):
        """
        Convert the fixed_effects into torch tensors.

        Side effects (applied on every call): A, B and C are smoothed with a
        Gaussian filter (sigma 0.75). Non-positive entries of B and C are first
        clamped to 1e-8, in place. When it is estimated, the initial velocity is
        normalized to unit norm per control point and divided by
        ``number_of_time_points``. The projected values are written back into
        ``self.fixed_effects``.

        :return: ``(template_data, template_points, control_points, impulse_t,
                 initial_velocity, A, B, C)`` as torch tensors.
        """
        # Template data.
        template_data = {
            key:
            Variable(torch.from_numpy(value).type(self.tensor_scalar_type),
                     requires_grad=(not self.freeze_template and with_grad))
            for key, value in self.fixed_effects['template_data'].items()
        }

        # Template points.
        template_points = {
            key:
            Variable(torch.from_numpy(value).type(self.tensor_scalar_type),
                     requires_grad=(not self.freeze_template and with_grad))
            for key, value in self.template.get_points().items()
        }

        # Control points.
        control_points = Variable(
            torch.from_numpy(self.fixed_effects['control_points']).type(
                self.tensor_scalar_type),
            requires_grad=(not self.freeze_control_points and with_grad))

        # Impulse.
        impulse_t = Variable(
            torch.from_numpy(self.fixed_effects['impulse_t']).type(
                self.tensor_scalar_type),
            requires_grad=with_grad)

        # Gompertz parameter images: smoothed (and, for B and C, kept strictly positive).
        A = gaussian_filter(self.fixed_effects['A'], sigma=0.75)
        self.fixed_effects['A'] = A
        A = Variable(torch.from_numpy(A).type(self.tensor_scalar_type),
                     requires_grad=with_grad)

        B = self.fixed_effects['B']
        B[B <= 0] = 1e-8
        B = gaussian_filter(B, sigma=0.75)
        self.fixed_effects['B'] = B
        B = Variable(torch.from_numpy(B).type(self.tensor_scalar_type),
                     requires_grad=with_grad)

        C = self.fixed_effects['C']
        C[C <= 0] = 1e-8
        C = gaussian_filter(C, sigma=0.75)
        self.fixed_effects['C'] = C
        C = Variable(torch.from_numpy(C).type(self.tensor_scalar_type),
                     requires_grad=with_grad)

        if self.estimate_initial_velocity:
            initial_velocity = self.fixed_effects['initial_velocity']
            # Scale each control point's velocity to unit norm (1e-6 avoids division by zero)...
            norms = LA.norm(initial_velocity, axis=1) + 1e-6
            initial_velocity = initial_velocity / norms.reshape(-1, 1)
            # ...then down-scale by the number of time steps.
            initial_velocity = initial_velocity / self.number_of_time_points
            self.fixed_effects['initial_velocity'] = initial_velocity
            initial_velocity = Variable(
                torch.from_numpy(initial_velocity).type(
                    self.tensor_scalar_type),
                requires_grad=with_grad)
        else:
            initial_velocity = Variable(
                torch.from_numpy(
                    np.zeros((self.number_of_control_points,
                              self.dimension))).type(self.tensor_scalar_type),
                requires_grad=False)

        return template_data, template_points, control_points, impulse_t, initial_velocity, A, B, C

    ####################################################################################################################
    ### Writing methods:
    ####################################################################################################################

    def write(self,
              dataset,
              population_RER,
              individual_RER,
              output_dir,
              write_adjoint_parameters=False):
        """Write the model predictions and the estimated parameters to ``output_dir``."""
        self._write_model_predictions(output_dir, dataset,
                                      write_adjoint_parameters)
        self._write_model_parameters(output_dir)

    def _write_model_predictions(self,
                                 output_dir,
                                 dataset=None,
                                 write_adjoint_parameters=False):
        """
        Write the following to ``output_dir``:
        - the intensity-only Gompertz images at integer times,
        - the acceleration path,
        - the reconstruction at each observation time,
        - the A, B, C parameter images.

        NOTE(review): despite the ``dataset=None`` default, ``dataset`` is
        required. Its first subject's times set the path's time range.
        """
        # Initialize ---------------------------------------------------------------------------------------------------
        template_data, template_points, control_points, impulse_t, initial_velocity, A, B, C = self._fixed_effects_to_torch_tensors(
            False)
        target_times = dataset.times[0]

        # Intensity-only model, sampled at the integer time indices 0..T-1.
        T, _, _ = impulse_t.shape
        for t in range(T):
            out_image = image.Image(self.dimension)
            intensities = self._gompertz_intensities(A, B, C, t)
            out_image.set_intensities(intensities.data.cpu().numpy())

            if self.dimension == 2:
                # 8-bit PNG, rescaled for display.
                out_image.set_dtype(np.dtype(np.uint8))
                img_name = '%s__intensity_only_model_%0.3d.png' % (self.name,
                                                                   t)
                out_image.write(output_dir, img_name, should_rescale=True)
            else:
                out_image.set_dtype(np.dtype(np.float32))
                img_name = '%s__intensity_only_model_%0.3d.nii' % (self.name,
                                                                   t)
                out_image.write(output_dir, img_name, should_rescale=False)

        # Deform -------------------------------------------------------------------------------------------------------
        self.acceleration_path.set_tmin(min(target_times))
        self.acceleration_path.set_tmax(max(target_times))
        self.acceleration_path.set_template_points_tmin(template_points)
        self.acceleration_path.set_control_points_tmin(control_points)
        self.acceleration_path.set_impulse_t(impulse_t)
        self.acceleration_path.set_initial_velocity(initial_velocity)
        self.acceleration_path.update()

        # Write --------------------------------------------------------------------------------------------------------
        self.acceleration_path.write(self.name, self.objects_name,
                                     self.objects_name_extension,
                                     self.template, template_data, A, B, C,
                                     output_dir, write_adjoint_parameters)

        # Model predictions.
        if dataset is not None:
            for j, time in enumerate(target_times):
                # NOTE(review): the observation index j (not its time) is
                # formatted with '%0.03f', e.g. '1.000'. This is kept for
                # output-name compatibility.
                names = []
                for object_name, object_extension in zip(
                        self.objects_name, self.objects_name_extension):
                    name = '%s__Reconstruction__%s__%0.03f%s' % (
                        self.name, object_name, j, object_extension)
                    print(name)
                    names.append(name)
                deformed_points = self.acceleration_path.get_template_points(
                    time)
                deformed_data = self.template.get_deformed_data(
                    deformed_points,
                    {'image_intensities': self._gompertz_intensities(A, B, C, time)})
                self.template.write(
                    output_dir, names, {
                        key: value.data.cpu().numpy()
                        for key, value in deformed_data.items()
                    })

        # Write the A, B and C images as float32 (TIF in 2D, NIfTI otherwise).
        extension = '.tif' if self.dimension == 2 else '.nii'
        for label, tensor in (('A', A), ('B', B), ('C', C)):
            param_image = image.Image(self.dimension)
            param_image.set_intensities(tensor.data.cpu().numpy())
            param_image.set_dtype(np.dtype(np.float32))
            param_image.write(output_dir,
                              '%s__%s_image%s' % (self.name, label, extension),
                              should_rescale=False)

    def _write_model_parameters(self, output_dir):
        """Write the control points, initial velocity and per-time-step impulses as text files."""
        # Control points.
        write_2D_array(self.get_control_points(), output_dir,
                       self.name + "__EstimatedParameters__ControlPoints.txt")

        # Initial velocity.
        write_3D_array(
            self.get_initial_velocity(), output_dir,
            self.name + "__EstimatedParameters__InitialVelocity.txt")

        # Impulse: one file per time step.
        impulse_t = self.acceleration_path.get_impulse_t()
        T, _, _ = impulse_t.shape
        for i in range(T):
            out_name = '%s__EstimatedParameters__Impulse_t_%0.3d.txt' % (
                self.name, i)
            cur_impulse = impulse_t[i, :, :].data.cpu().numpy()
            write_3D_array(cur_impulse, output_dir, out_name)
    xml_parameters = XmlParameters()
    xml_parameters._read_model_xml(model_xml_path)

    deformetrica = Deformetrica(output_dir=output_dir)
    template_specifications, model_options, _ = deformetrica.further_initialization(
        'Shooting', xml_parameters.template_specifications, get_model_options(xml_parameters))

    """
    Load the template, control points, momenta, modulation matrix.
    """

    # Template.
    t_list, objects_name, objects_name_extension, _, _ = create_template_metadata(template_specifications)

    template = DeformableMultiObject(t_list)
    template_data = {key: torch.from_numpy(value).type(model_options['tensor_scalar_type'])
                     for key, value in template.get_data().items()}
    template_points = {key: torch.from_numpy(value).type(model_options['tensor_scalar_type'])
                       for key, value in template.get_points().items()}

    # Control points.
    control_points = read_2D_array(model_options['initial_control_points'])
    logger.info('>> Reading ' + str(len(control_points)) + ' initial control points from file: '
          + model_options['initial_control_points'])
    control_points = torch.from_numpy(control_points).type(model_options['tensor_scalar_type'])

    # Momenta.
    momenta = read_3D_array(model_options['initial_momenta'])
    logger.info('>> Reading initial momenta from file: ' + model_options['initial_momenta'])
    momenta = torch.from_numpy(momenta).type(model_options['tensor_scalar_type'])
コード例 #11
0
class GeodesicRegression(AbstractStatisticalModel):
    """
    Geodesic regression object class.

    A template is shot along a geodesic parameterized by control points and initial momenta, and the
    deformed template is compared to the observations of the first subject of a longitudinal dataset at
    each of its visit times.
    """

    ####################################################################################################################
    ### Constructor:
    ####################################################################################################################

    def __init__(self):
        AbstractStatisticalModel.__init__(self)

        self.template = DeformableMultiObject()
        self.objects_name = []
        self.objects_name_extension = []
        self.objects_noise_variance = []

        self.multi_object_attachment = MultiObjectAttachment()
        self.geodesic = Geodesic()

        self.use_sobolev_gradient = True
        self.smoothing_kernel_width = None

        self.initial_cp_spacing = None
        self.number_of_objects = None
        self.number_of_control_points = None
        self.bounding_box = None

        # Dictionary of numpy arrays.
        self.fixed_effects['template_data'] = None
        self.fixed_effects['control_points'] = None
        self.fixed_effects['momenta'] = None

        # When frozen, the corresponding fixed effects are excluded from get/set_fixed_effects and get no gradient.
        self.freeze_template = False
        self.freeze_control_points = False

    ####################################################################################################################
    ### Encapsulation methods:
    ####################################################################################################################

    # Template data ----------------------------------------------------------------------------------------------------
    def get_template_data(self):
        return self.fixed_effects['template_data']

    def set_template_data(self, td):
        """Store the template data fixed effect and push it into the template object."""
        self.fixed_effects['template_data'] = td
        self.template.set_data(td)

    # Control points ---------------------------------------------------------------------------------------------------
    def get_control_points(self):
        return self.fixed_effects['control_points']

    def set_control_points(self, cp):
        """Store the control points fixed effect and keep the cached control point count in sync."""
        self.fixed_effects['control_points'] = cp
        self.number_of_control_points = len(cp)

    # Momenta ----------------------------------------------------------------------------------------------------------
    def get_momenta(self):
        return self.fixed_effects['momenta']

    def set_momenta(self, mom):
        self.fixed_effects['momenta'] = mom

    # Full fixed effects -----------------------------------------------------------------------------------------------
    def get_fixed_effects(self):
        """
        Return the optimizable fixed effects as a flat dictionary.

        Template data entries are flattened into the top level; frozen parts are omitted.
        """
        out = {}
        if not self.freeze_template:
            for key, value in self.fixed_effects['template_data'].items():
                out[key] = value
        if not self.freeze_control_points:
            out['control_points'] = self.fixed_effects['control_points']
        out['momenta'] = self.fixed_effects['momenta']
        return out

    def set_fixed_effects(self, fixed_effects):
        """Inverse of get_fixed_effects: read back the non-frozen fixed effects from a flat dictionary."""
        if not self.freeze_template:
            template_data = {key: fixed_effects[key] for key in self.fixed_effects['template_data'].keys()}
            self.set_template_data(template_data)
        if not self.freeze_control_points:
            self.set_control_points(fixed_effects['control_points'])
        self.set_momenta(fixed_effects['momenta'])

    ####################################################################################################################
    ### Public methods:
    ####################################################################################################################

    def update(self):
        """
        Final initialization steps.

        Refreshes the template, then initializes the control points (or widens the bounding box around
        user-given ones) and the momenta when they have not been provided.
        """

        self.template.update()
        self.number_of_objects = len(self.template.object_list)
        self.bounding_box = self.template.bounding_box

        self.set_template_data(self.template.get_data())

        if self.fixed_effects['control_points'] is None:
            self._initialize_control_points()
        else:
            self._initialize_bounding_box()

        if self.fixed_effects['momenta'] is None:
            self._initialize_momenta()

    # Compute the functional. Numpy input/outputs.
    def compute_log_likelihood(self, dataset, population_RER, individual_RER, mode='complete', with_grad=False):
        """
        Compute the log-likelihood of the dataset given the current fixed effects.

        :param dataset: LongitudinalDataset instance; only its first subject is used.
        :param population_RER: Dictionary of population random effects realizations (not used by this model).
        :param individual_RER: Dictionary of individual random effects realizations (not used by this model).
        :param mode: Not used by this model; kept for interface compatibility.
        :param with_grad: Flag that indicates whether the gradient should be returned as well.
        :return: (attachment, regularity) as numpy values; when with_grad is True, a third element holds a
            dictionary of numpy gradients for the non-frozen fixed effects.
        """
        # Initialize: conversion from numpy to torch.
        template_data, template_points, control_points, momenta = self._fixed_effects_to_torch_tensors(with_grad)

        # Deform.
        attachment, regularity = self._compute_attachment_and_regularity(
            dataset, template_data, template_points, control_points, momenta)

        if not with_grad:
            return attachment.detach().cpu().numpy(), regularity.detach().cpu().numpy()

        total = regularity + attachment
        total.backward()

        gradient = {}
        # Template data: landmarks are differentiated through their points, images through their intensities.
        if not self.freeze_template:
            if 'landmark_points' in template_data.keys():
                gradient['landmark_points'] = template_points['landmark_points'].grad
            if 'image_intensities' in template_data.keys():
                gradient['image_intensities'] = template_data['image_intensities'].grad

            if self.use_sobolev_gradient and 'landmark_points' in gradient.keys():
                gradient['landmark_points'] = compute_sobolev_gradient(
                    gradient['landmark_points'], self.smoothing_kernel_width, self.template)

        # Control points and momenta.
        if not self.freeze_control_points:
            gradient['control_points'] = control_points.grad
        gradient['momenta'] = momenta.grad

        # Convert the gradient back to numpy.
        gradient = {key: value.data.cpu().numpy() for key, value in gradient.items()}

        return attachment.detach().cpu().numpy(), regularity.detach().cpu().numpy(), gradient

    def initialize_template_attributes(self, template_specifications):
        """
        Set the template objects, their names and extensions, the noise variances and the multi-object
        attachment from the template specifications.
        """

        t_list, t_name, t_name_extension, t_noise_variance, t_multi_object_attachment = \
            create_template_metadata(template_specifications)

        self.template.object_list = t_list
        self.objects_name = t_name
        self.objects_name_extension = t_name_extension
        self.objects_noise_variance = t_noise_variance
        self.multi_object_attachment = t_multi_object_attachment

    ####################################################################################################################
    ### Private methods:
    ####################################################################################################################

    def _compute_attachment_and_regularity(self, dataset, template_data, template_points, control_points, momenta):
        """
        Core part of the ComputeLogLikelihood methods. Fully torch.

        :return: (attachment, regularity), both negated quantities: minus the summed weighted distances to the
            observations, and minus the squared norm of the geodesic.
        """

        # Initialize: cross-sectional dataset (first subject only).
        target_times = dataset.times[0]
        target_objects = dataset.deformable_objects[0]

        # Deform.
        self.geodesic.set_tmin(min(target_times))
        self.geodesic.set_tmax(max(target_times))
        self.geodesic.set_template_points_t0(template_points)
        self.geodesic.set_control_points_t0(control_points)
        self.geodesic.set_momenta_t0(momenta)
        self.geodesic.update()

        attachment = 0.
        for time, obj in zip(target_times, target_objects):
            deformed_points = self.geodesic.get_template_points(time)
            deformed_data = self.template.get_deformed_data(deformed_points, template_data)
            attachment -= self.multi_object_attachment.compute_weighted_distance(
                deformed_data, self.template, obj, self.objects_noise_variance)
        regularity = - self.geodesic.get_norm_squared()

        return attachment, regularity

    def _initialize_control_points(self):
        """
        Initialize the control points fixed effect.

        Outside dense mode, a regular grid over the bounding box with spacing initial_cp_spacing is used.
        In dense mode, the template landmark points themselves are used.
        """
        if not Settings().dense_mode:
            control_points = create_regular_grid_of_points(self.bounding_box, self.initial_cp_spacing)
        else:
            # template.get_points() returns a dictionary keyed by point type, not an array.
            control_points = self.template.get_points()['landmark_points']

        # set_control_points also updates number_of_control_points.
        self.set_control_points(control_points)
        logger.info('Set of ' + str(self.number_of_control_points) + ' control points defined.')

    def _initialize_momenta(self):
        """
        Initialize the momenta fixed effect to zero, one row per control point.
        """
        momenta = np.zeros((self.number_of_control_points, Settings().dimension))
        self.set_momenta(momenta)

    def _initialize_bounding_box(self):
        """
        Initialize the bounding box, which tightly encloses all template objects and the atlas control points.
        Relevant when the control points are given by the user.
        """
        assert (self.number_of_control_points > 0)

        dimension = Settings().dimension
        control_points = self.get_control_points()[:self.number_of_control_points, :dimension]

        # Widen the box (column 0: lower bounds, column 1: upper bounds) so it contains every control point.
        self.bounding_box[:dimension, 0] = np.minimum(self.bounding_box[:dimension, 0], control_points.min(axis=0))
        self.bounding_box[:dimension, 1] = np.maximum(self.bounding_box[:dimension, 1], control_points.max(axis=0))

    ####################################################################################################################
    ### Private utility methods:
    ####################################################################################################################

    def _fixed_effects_to_torch_tensors(self, with_grad):
        """
        Convert the fixed effects (numpy arrays) into torch Variables.

        :param with_grad: Whether gradients are required for the non-frozen fixed effects.
        :return: (template_data, template_points, control_points, momenta); the first two are dictionaries.
        """
        tensor_type = Settings().tensor_scalar_type
        template_requires_grad = not self.freeze_template and with_grad

        # Template data.
        template_data = {key: Variable(torch.from_numpy(value).type(tensor_type), requires_grad=template_requires_grad)
                         for key, value in self.fixed_effects['template_data'].items()}

        # Template points.
        template_points = {key: Variable(torch.from_numpy(value).type(tensor_type),
                                         requires_grad=template_requires_grad)
                           for key, value in self.template.get_points().items()}

        # Control points.
        if Settings().dense_mode:
            # In dense mode the template landmark points act as control points; template_data is a dictionary,
            # so the landmark points tensor must be selected explicitly.
            control_points = template_points['landmark_points']
        else:
            # NOTE(review): the keops kernel path also forces requires_grad on the control points; presumably
            # required by that backend — confirm.
            control_points = Variable(torch.from_numpy(self.fixed_effects['control_points']).type(tensor_type),
                                      requires_grad=((not self.freeze_control_points and with_grad)
                                                     or self.geodesic.get_kernel_type() == 'keops'))

        # Momenta.
        momenta = Variable(torch.from_numpy(self.fixed_effects['momenta']).type(tensor_type),
                           requires_grad=with_grad)

        return template_data, template_points, control_points, momenta

    ####################################################################################################################
    ### Writing methods:
    ####################################################################################################################

    def write(self, dataset=None, population_RER=None, individual_RER=None, write_adjoint_parameters=False):
        """Write the model predictions and the estimated parameters. The random effects arguments are unused."""
        self._write_model_predictions(dataset, write_adjoint_parameters)
        self._write_model_parameters()

    def _write_model_predictions(self, dataset=None, write_adjoint_parameters=False):
        """
        Write the geodesic flow and, when a dataset is given, the reconstruction at each of its visit times.

        Without a dataset, the geodesic keeps the time interval it was last given.
        """
        template_data, template_points, control_points, momenta = self._fixed_effects_to_torch_tensors(False)

        # Deform. Use the same setters as _compute_attachment_and_regularity rather than assigning attributes.
        target_times = None
        if dataset is not None:
            target_times = dataset.times[0]
            self.geodesic.set_tmin(min(target_times))
            self.geodesic.set_tmax(max(target_times))
        self.geodesic.set_template_points_t0(template_points)
        self.geodesic.set_control_points_t0(control_points)
        self.geodesic.set_momenta_t0(momenta)
        self.geodesic.update()

        # Geodesic flow.
        self.geodesic.write(self.name, self.objects_name, self.objects_name_extension, self.template, template_data,
                            write_adjoint_parameters)

        # Model predictions.
        if target_times is None:
            return

        for j, time in enumerate(target_times):
            names = []
            for object_name, object_extension in zip(self.objects_name, self.objects_name_extension):
                name = self.name + '__Reconstruction__' + object_name + '__tp_' + str(j) + ('__age_%.2f' % time) \
                       + object_extension
                names.append(name)
            deformed_points = self.geodesic.get_template_points(time)
            deformed_data = self.template.get_deformed_data(deformed_points, template_data)
            self.template.write(names, {key: value.data.cpu().numpy() for key, value in deformed_data.items()})

    def _write_model_parameters(self):
        """Write the template (at the geodesic's t0), the control points and the momenta."""
        # Template.
        template_names = []
        for k in range(len(self.objects_name)):
            aux = self.name + '__EstimatedParameters__Template_' + self.objects_name[k] + '__tp_' \
                  + str(self.geodesic.backward_exponential.number_of_time_points - 1) \
                  + ('__age_%.2f' % self.geodesic.t0) + self.objects_name_extension[k]
            template_names.append(aux)
        self.template.write(template_names)

        # Control points.
        write_2D_array(self.get_control_points(), self.name + "__EstimatedParameters__ControlPoints.txt")

        # Momenta.
        write_3D_array(self.get_momenta(), self.name + "__EstimatedParameters__Momenta.txt")
コード例 #12
0
class AccelerationRegression(AbstractStatisticalModel):
    """
    Acceleration regression object class.

    The template is transported along an AccelerationPath driven by time-dependent impulses and, optionally,
    an estimated initial velocity. The deformed template is compared to the observations of the first
    subject of a longitudinal dataset at each of its visit times.
    """

    ####################################################################################################################
    ### Constructor:
    ####################################################################################################################

    def __init__(self,
                 template_specifications,
                 dimension=default.dimension,
                 tensor_scalar_type=default.tensor_scalar_type,
                 tensor_integer_type=default.tensor_integer_type,
                 number_of_threads=default.number_of_threads,
                 deformation_kernel_type=default.deformation_kernel_type,
                 deformation_kernel_width=default.deformation_kernel_width,
                 deformation_kernel_device=default.deformation_kernel_device,
                 shoot_kernel_type=default.shoot_kernel_type,
                 number_of_time_points=default.number_of_time_points,
                 freeze_template=default.freeze_template,
                 use_sobolev_gradient=default.use_sobolev_gradient,
                 smoothing_kernel_width=default.smoothing_kernel_width,
                 estimate_initial_velocity=default.estimate_initial_velocity,
                 initial_velocity_weight=default.initial_velocity_weight,
                 regularity_weight=default.regularity_weight,
                 data_weight=default.data_weight,
                 initial_control_points=default.initial_control_points,
                 freeze_control_points=default.freeze_control_points,
                 initial_cp_spacing=default.initial_cp_spacing,
                 initial_impulse_t=None,
                 initial_velocity=None,
                 **kwargs):

        AbstractStatisticalModel.__init__(self, name='AccelerationRegression')

        # Global-like attributes
        self.dimension = dimension
        self.tensor_scalar_type = tensor_scalar_type
        self.tensor_integer_type = tensor_integer_type
        self.number_of_threads = number_of_threads

        # Declare model structure
        self.fixed_effects['template_data'] = None
        self.fixed_effects['control_points'] = None

        # When frozen, the corresponding fixed effects are excluded from get/set_fixed_effects and get no gradient.
        self.freeze_template = freeze_template
        self.freeze_control_points = freeze_control_points

        # Deformation
        self.acceleration_path = AccelerationPath(
            kernel=kernel_factory.factory(deformation_kernel_type,
                                          deformation_kernel_width,
                                          device=deformation_kernel_device),
            shoot_kernel_type=shoot_kernel_type,
            number_of_time_points=number_of_time_points)

        # Template
        (object_list, self.objects_name, self.objects_name_extension,
         self.multi_object_attachment) = create_template_metadata(
             template_specifications, self.dimension)

        self.template = DeformableMultiObject(object_list)
        self.template.update()

        self.number_of_objects = len(self.template.object_list)

        # The Sobolev kernel smooths the landmark gradient (see compute_log_likelihood).
        self.use_sobolev_gradient = use_sobolev_gradient
        self.smoothing_kernel_width = smoothing_kernel_width
        if self.use_sobolev_gradient:
            self.sobolev_kernel = kernel_factory.factory(
                deformation_kernel_type,
                smoothing_kernel_width,
                device=deformation_kernel_device)

        # Template data
        self.fixed_effects['template_data'] = self.template.get_data()

        # Control points
        self.fixed_effects['control_points'] = initialize_control_points(
            initial_control_points, self.template, initial_cp_spacing,
            deformation_kernel_width, self.dimension, False)

        # Weights of the three terms of the objective.
        self.estimate_initial_velocity = estimate_initial_velocity
        self.initial_velocity_weight = initial_velocity_weight
        self.regularity_weight = regularity_weight
        self.data_weight = data_weight

        self.number_of_control_points = len(self.fixed_effects['control_points'])
        self.number_of_time_points = number_of_time_points

        # Impulse
        self.fixed_effects['impulse_t'] = initialize_impulse(
            initial_impulse_t, self.number_of_time_points,
            self.number_of_control_points, self.dimension)
        if self.estimate_initial_velocity:
            self.fixed_effects['initial_velocity'] = initialize_initial_velocity(
                initial_velocity, self.number_of_control_points, self.dimension)

    ####################################################################################################################
    ### Encapsulation methods:
    ####################################################################################################################

    # Template data ----------------------------------------------------------------------------------------------------
    def get_template_data(self):
        return self.fixed_effects['template_data']

    def set_template_data(self, td):
        """Store the template data fixed effect and push it into the template object."""
        self.fixed_effects['template_data'] = td
        self.template.set_data(td)

    # Control points ---------------------------------------------------------------------------------------------------
    def get_control_points(self):
        return self.fixed_effects['control_points']

    def set_control_points(self, cp):
        """Store the control points fixed effect and keep the cached control point count in sync."""
        self.fixed_effects['control_points'] = cp
        # number_of_control_points sizes the zero initial velocity, so it must follow the control points.
        self.number_of_control_points = len(cp)

    # Impulse ----------------------------------------------------------------------------------------------------------
    def get_impulse_t(self):
        return self.fixed_effects['impulse_t']

    def set_impulse_t(self, impulse_t):
        self.fixed_effects['impulse_t'] = impulse_t

    def get_initial_velocity(self):
        """Return the estimated initial velocity, or zeros (one row per control point) when it is not estimated."""
        if self.estimate_initial_velocity:
            return self.fixed_effects['initial_velocity']
        return np.zeros((self.number_of_control_points, self.dimension))

    def set_initial_velocity(self, initial_velocity):
        self.fixed_effects['initial_velocity'] = initial_velocity

    # Full fixed effects -----------------------------------------------------------------------------------------------
    def get_fixed_effects(self):
        """
        Return the optimizable fixed effects as a flat dictionary.

        Template data entries are flattened into the top level; frozen parts are omitted, and the initial
        velocity is only included when it is estimated.
        """
        out = {}
        if not self.freeze_template:
            for key, value in self.fixed_effects['template_data'].items():
                out[key] = value
        if not self.freeze_control_points:
            out['control_points'] = self.fixed_effects['control_points']
        out['impulse_t'] = self.fixed_effects['impulse_t']
        if self.estimate_initial_velocity:
            out['initial_velocity'] = self.fixed_effects['initial_velocity']
        return out

    def set_fixed_effects(self, fixed_effects):
        """Inverse of get_fixed_effects: read back the non-frozen fixed effects from a flat dictionary."""
        if not self.freeze_template:
            template_data = {key: fixed_effects[key] for key in self.fixed_effects['template_data'].keys()}
            self.set_template_data(template_data)
        if not self.freeze_control_points:
            self.set_control_points(fixed_effects['control_points'])
        self.set_impulse_t(fixed_effects['impulse_t'])
        if self.estimate_initial_velocity:
            self.set_initial_velocity(fixed_effects['initial_velocity'])

    ####################################################################################################################
    ### Public methods:
    ####################################################################################################################

    # Compute the functional. Numpy input/outputs.
    def compute_log_likelihood(self, dataset, with_grad=False):
        """
        Compute the weighted data-attachment and regularity terms for the dataset.

        :param dataset: LongitudinalDataset instance; only its first subject is used.
        :param with_grad: Flag that indicates whether the gradient should be returned as well.
        :return: (data_weight * attachment, regularity_weight * impulse regularity
            + initial_velocity_weight * velocity regularity) as numpy values; when with_grad is True, a third
            element holds the dictionary of numpy gradients of their sum.
        """
        # Initialize: conversion from numpy to torch.
        template_data, template_points, control_points, impulse_t, initial_velocity = \
            self._fixed_effects_to_torch_tensors(with_grad)

        # Deform.
        data_attachment, regularity, velocity_regularity = self._compute_attachment_and_regularity(
            dataset, template_data, template_points, control_points, impulse_t, initial_velocity)

        attachment_value = self.data_weight * data_attachment.detach().cpu().numpy()
        regularity_value = (self.regularity_weight * regularity.detach().cpu().numpy()
                            + self.initial_velocity_weight * velocity_regularity.detach().cpu().numpy())

        if not with_grad:
            return attachment_value, regularity_value

        total = (self.initial_velocity_weight * velocity_regularity
                 + self.regularity_weight * regularity
                 + self.data_weight * data_attachment)
        total.backward()

        gradient = {}
        # Template data: landmarks are differentiated through their points, images through their intensities.
        if not self.freeze_template:
            if 'landmark_points' in template_data.keys():
                gradient['landmark_points'] = template_points['landmark_points'].grad
            if 'image_intensities' in template_data.keys():
                gradient['image_intensities'] = template_data['image_intensities'].grad

            if self.use_sobolev_gradient and 'landmark_points' in gradient.keys():
                gradient['landmark_points'] = self.sobolev_kernel.convolve(
                    template_data['landmark_points'].detach(),
                    template_data['landmark_points'].detach(),
                    gradient['landmark_points'].detach())

        # Control points
        if not self.freeze_control_points:
            gradient['control_points'] = control_points.grad

        # Initial velocity
        if self.estimate_initial_velocity:
            gradient['initial_velocity'] = initial_velocity.grad

        # Impulse t
        gradient['impulse_t'] = impulse_t.grad

        # Convert the gradient back to numpy.
        gradient = {key: value.data.cpu().numpy() for key, value in gradient.items()}

        return attachment_value, regularity_value, gradient

    ####################################################################################################################
    ### Private methods:
    ####################################################################################################################

    def _update_acceleration_path(self, target_times, template_points, control_points, impulse_t,
                                  initial_velocity):
        """
        Load the given tensors into the acceleration path and recompute it.

        When target_times is None, the path keeps the time interval it was last given.
        """
        if target_times is not None:
            self.acceleration_path.set_tmin(min(target_times))
            self.acceleration_path.set_tmax(max(target_times))
        self.acceleration_path.set_template_points_tmin(template_points)
        self.acceleration_path.set_control_points_tmin(control_points)
        self.acceleration_path.set_impulse_t(impulse_t)
        self.acceleration_path.set_initial_velocity(initial_velocity)
        self.acceleration_path.update()

    def _compute_attachment_and_regularity(self, dataset, template_data,
                                           template_points, control_points,
                                           impulse_t, initial_velocity):
        """
        Core part of the ComputeLogLikelihood methods. Fully torch.

        :return: (data_attachment, regularity, velocity_regularity), all unweighted: the summed weighted
            distances to the observations, the path's squared norm and its velocity norm.
        """
        # Initialize: cross-sectional dataset (first subject only).
        target_times = dataset.times[0]
        target_objects = dataset.deformable_objects[0]

        # Deform.
        self._update_acceleration_path(target_times, template_points, control_points, impulse_t, initial_velocity)

        data_attachment = 0.0
        for time, obj in zip(target_times, target_objects):
            deformed_points = self.acceleration_path.get_template_points(time)
            deformed_data = self.template.get_deformed_data(deformed_points, template_data)
            data_attachment += self.multi_object_attachment.compute_weighted_distance(
                deformed_data, self.template, obj)

        regularity = self.acceleration_path.get_norm_squared()
        velocity_regularity = self.acceleration_path.get_velocity_norm()

        return data_attachment, regularity, velocity_regularity

    ####################################################################################################################
    ### Private utility methods:
    ####################################################################################################################

    def _to_variable(self, array, requires_grad):
        """Wrap a numpy array into a torch Variable of this model's scalar type."""
        return Variable(torch.from_numpy(array).type(self.tensor_scalar_type), requires_grad=requires_grad)

    def _fixed_effects_to_torch_tensors(self, with_grad):
        """
        Convert the fixed effects (numpy arrays) into torch Variables.

        Side effect: when the initial velocity is estimated, its normalized version is written back into
        self.fixed_effects['initial_velocity'].

        :param with_grad: Whether gradients are required for the non-frozen fixed effects.
        :return: (template_data, template_points, control_points, impulse_t, initial_velocity); the first two
            are dictionaries.
        """
        template_requires_grad = not self.freeze_template and with_grad

        # Template data
        template_data = {key: self._to_variable(value, template_requires_grad)
                         for key, value in self.fixed_effects['template_data'].items()}

        # Template points
        template_points = {key: self._to_variable(value, template_requires_grad)
                           for key, value in self.template.get_points().items()}

        # Control points
        control_points = self._to_variable(self.fixed_effects['control_points'],
                                           not self.freeze_control_points and with_grad)

        # Impulse
        impulse_t = self._to_variable(self.fixed_effects['impulse_t'], with_grad)

        if self.estimate_initial_velocity:
            initial_velocity = self.fixed_effects['initial_velocity']
            # Scale each row to (almost) unit norm; the epsilon guards against division by zero.
            norms = LA.norm(initial_velocity, axis=1) + 1e-6
            initial_velocity = initial_velocity / norms.reshape(-1, 1)
            # Now scale to the number of timesteps
            initial_velocity = initial_velocity / self.number_of_time_points
            self.fixed_effects['initial_velocity'] = initial_velocity
            initial_velocity = self._to_variable(initial_velocity, with_grad)
        else:
            initial_velocity = self._to_variable(
                np.zeros((self.number_of_control_points, self.dimension)), False)

        return template_data, template_points, control_points, impulse_t, initial_velocity

    ####################################################################################################################
    ### Writing methods:
    ####################################################################################################################

    def write(self, dataset, output_dir):
        """Write the model predictions and the estimated parameters into output_dir."""
        self._write_model_predictions(output_dir, dataset)
        self._write_model_parameters(output_dir)

    def _write_model_predictions(self, output_dir, dataset=None):
        """
        Write the acceleration path and, when a dataset is given, the reconstruction at each of its visit times.

        Without a dataset, the path keeps the time interval it was last given.
        """
        template_data, template_points, control_points, impulse_t, initial_velocity = \
            self._fixed_effects_to_torch_tensors(False)
        target_times = dataset.times[0] if dataset is not None else None

        # Deform.
        self._update_acceleration_path(target_times, template_points, control_points, impulse_t, initial_velocity)

        # Write the path.
        self.acceleration_path.write(self.name, self.objects_name,
                                     self.objects_name_extension,
                                     self.template, template_data, output_dir)

        # Model predictions
        if target_times is None:
            return

        for j, time in enumerate(target_times):
            names = []
            for object_name, object_extension in zip(self.objects_name, self.objects_name_extension):
                # NOTE(review): the integer index j is formatted with %0.03f (e.g. '1.000'); if the time value
                # was intended, use `time` instead. Left unchanged to preserve existing output file names.
                name = '%s__Reconstruction__%s__%0.03f%s' % (
                    self.name, object_name, j, object_extension)
                print(name)
                names.append(name)
            deformed_points = self.acceleration_path.get_template_points(time)
            deformed_data = self.template.get_deformed_data(deformed_points, template_data)
            self.template.write(
                output_dir, names,
                {key: value.data.cpu().numpy() for key, value in deformed_data.items()})

    def _write_model_parameters(self, output_dir):
        """Write the control points, the initial velocity and one file per impulse time step."""
        # Control points
        write_2D_array(self.get_control_points(), output_dir,
                       self.name + "__EstimatedParameters__ControlPoints.txt")

        # Initial velocity
        write_3D_array(
            self.get_initial_velocity(), output_dir,
            self.name + "__EstimatedParameters__InitialVelocity.txt")

        # Write impulse, one file per time step (first axis).
        impulse_t = self.acceleration_path.get_impulse_t()
        for i in range(impulse_t.shape[0]):
            out_name = '%s__EstimatedParameters__Impulse_t_%0.3d.txt' % (
                self.name, i)
            cur_impulse = impulse_t[i, :, :].data.cpu().numpy()
            write_3D_array(cur_impulse, output_dir, out_name)
コード例 #13
0
class BayesianAtlas(AbstractStatisticalModel):
    """
    Bayesian atlas object class.

    A single template is deformed towards the first observation of every subject, each subject having its own
    momenta. The momenta are a zero-mean normal random effect whose covariance carries an inverse-Wishart prior;
    every template object has its own residual noise variance, with a scalar inverse-Wishart prior.
    """

    ####################################################################################################################
    ### Constructor:
    ####################################################################################################################

    def __init__(self):
        AbstractStatisticalModel.__init__(self)

        self.template = DeformableMultiObject()
        self.objects_name = []
        self.objects_name_extension = []
        self.objects_noise_dimension = []

        self.multi_object_attachment = None
        self.exponential = Exponential()

        self.use_sobolev_gradient = True
        self.smoothing_kernel_width = None

        self.initial_cp_spacing = None
        self.number_of_objects = None
        self.number_of_control_points = None
        self.bounding_box = None

        # Dictionary of numpy arrays (template_data is itself a dictionary of numpy arrays).
        self.fixed_effects['template_data'] = None
        self.fixed_effects['control_points'] = None
        self.fixed_effects['covariance_momenta_inverse'] = None
        self.fixed_effects['noise_variance'] = None

        # Dictionary of probability distributions: priors on the class 1 fixed effects.
        self.priors['covariance_momenta'] = InverseWishartDistribution()
        self.priors['noise_variance'] = MultiScalarInverseWishartDistribution()

        # Dictionary of probability distributions: individual random effects.
        self.individual_random_effects['momenta'] = NormalDistribution()

        self.freeze_template = False
        self.freeze_control_points = False

    ####################################################################################################################
    ### Encapsulation methods:
    ####################################################################################################################

    # Template data ----------------------------------------------------------------------------------------------------
    def get_template_data(self):
        return self.fixed_effects['template_data']

    def set_template_data(self, td):
        """Store the template data (dictionary of numpy arrays) and push it to the template object."""
        self.fixed_effects['template_data'] = td
        self.template.set_data(td)

    # Control points ---------------------------------------------------------------------------------------------------
    def get_control_points(self):
        return self.fixed_effects['control_points']

    def set_control_points(self, cp):
        """Store the control points and refresh the cached number of control points."""
        self.fixed_effects['control_points'] = cp
        self.number_of_control_points = len(cp)

    # Covariance momenta inverse ---------------------------------------------------------------------------------------
    def get_covariance_momenta_inverse(self):
        return self.fixed_effects['covariance_momenta_inverse']

    def set_covariance_momenta_inverse(self, cmi):
        """Store the inverse momenta covariance and forward it to the momenta normal distribution."""
        self.fixed_effects['covariance_momenta_inverse'] = cmi
        self.individual_random_effects['momenta'].set_covariance_inverse(cmi)

    def set_covariance_momenta(self, cm):
        """Set the momenta covariance from its (non-inverted) matrix."""
        self.set_covariance_momenta_inverse(np.linalg.inv(cm))

    # Noise variance ---------------------------------------------------------------------------------------------------
    def get_noise_variance(self):
        return self.fixed_effects['noise_variance']

    def set_noise_variance(self, nv):
        self.fixed_effects['noise_variance'] = nv

    # Full fixed effects -----------------------------------------------------------------------------------------------
    def get_fixed_effects(self):
        """
        Return the class 2 fixed effects as a flat dictionary: every template data entry (unless the template is
        frozen) and 'control_points' (unless the control points are frozen).
        """
        out = {}
        if not self.freeze_template:
            out.update(self.fixed_effects['template_data'])
        if not self.freeze_control_points:
            out['control_points'] = self.fixed_effects['control_points']
        return out

    def set_fixed_effects(self, fixed_effects):
        """Inverse of get_fixed_effects: read back the non-frozen class 2 fixed effects from a flat dictionary."""
        if not self.freeze_template:
            template_data = {key: fixed_effects[key] for key in self.fixed_effects['template_data'].keys()}
            self.set_template_data(template_data)
        if not self.freeze_control_points:
            self.set_control_points(fixed_effects['control_points'])

    ####################################################################################################################
    ### Public methods:
    ####################################################################################################################

    def update(self):
        """
        Final initialization steps: template data, control points (or bounding box when the control points were
        given by the user), momenta distribution and noise variance.
        """
        self.number_of_objects = len(self.template.object_list)
        self.bounding_box = self.template.bounding_box

        self.set_template_data(self.template.get_data())

        if self.fixed_effects['control_points'] is None:
            self._initialize_control_points()
        else:
            self._initialize_bounding_box()

        self._initialize_momenta()
        self._initialize_noise_variance()

    def compute_log_likelihood(self, dataset, population_RER, individual_RER, mode='complete', with_grad=False):
        """
        Compute the log-likelihood of the dataset, given the fixed effects and the random effects realizations.
        In 'complete' mode, the class 1 fixed effects (momenta covariance, noise variance) are first updated in
        closed form.

        :param dataset: LongitudinalDataset instance.
        :param population_RER: Dictionary of population random effects realizations.
        :param individual_RER: Dictionary of individual random effects realizations ('momenta').
        :param mode: 'complete' (attachment + all regularity terms, momenta gradient included), 'class2'
            (attachment + class 2 priors) or 'model' (per-subject attachments only).
        :param with_grad: Flag that indicates whether the gradient should be returned as well.
        :return: 'complete'/'class2': (attachment, regularity) as numpy arrays; 'model': per-subject attachments as
            a numpy array. When with_grad is True, a dictionary of numpy gradients is appended to the return value.
        """

        # Initialize: conversion from numpy to torch -------------------------------------------------------------------
        template_data, template_points, control_points = self._fixed_effects_to_torch_tensors(with_grad)
        momenta = self._individual_RER_to_torch_tensors(individual_RER, with_grad and mode == 'complete')

        # Deform, update, compute metrics ------------------------------------------------------------------------------
        residuals = self._compute_residuals(dataset, template_data, template_points, control_points, momenta)

        # Update the fixed effects only if the user asked for the complete log likelihood.
        if mode == 'complete':
            sufficient_statistics = self.compute_sufficient_statistics(
                dataset, population_RER, individual_RER, residuals=residuals)
            self.update_fixed_effects(dataset, sufficient_statistics)

        # Compute the attachment, with the updated noise variance parameter in the 'complete' mode.
        attachments = self._compute_individual_attachments(residuals)
        attachment = torch.sum(attachments)

        # Compute the regularity terms according to the mode. Depending on the mode, this may remain a plain Python
        # float (e.g. 'class2', whose priors are not implemented yet), hence the _to_numpy conversions below.
        regularity = 0.0
        if mode == 'complete':
            regularity = self._compute_random_effects_regularity(momenta)
            regularity += self._compute_class1_priors_regularity()
        if mode in ['complete', 'class2']:
            regularity += self._compute_class2_priors_regularity(template_data, control_points)

        # Compute gradient if needed -----------------------------------------------------------------------------------
        if with_grad:
            total = regularity + attachment
            total.backward()

            gradient = {}

            # Template data.
            if not self.freeze_template:
                if 'landmark_points' in template_data.keys():
                    gradient['landmark_points'] = template_points['landmark_points'].grad
                if 'image_intensities' in template_data.keys():
                    gradient['image_intensities'] = template_data['image_intensities'].grad

                if self.use_sobolev_gradient and 'landmark_points' in gradient.keys():
                    gradient['landmark_points'] = compute_sobolev_gradient(
                        gradient['landmark_points'], self.smoothing_kernel_width, self.template)

            # Control points.
            if not self.freeze_control_points:
                gradient['control_points'] = control_points.grad

            # Individual effects.
            if mode == 'complete':
                gradient['momenta'] = momenta.grad

            # Convert to numpy.
            gradient_numpy = {key: value.data.cpu().numpy() for key, value in gradient.items()}

            # Return as appropriate.
            if mode in ['complete', 'class2']:
                return self._to_numpy(attachment), self._to_numpy(regularity), gradient_numpy
            elif mode == 'model':
                return self._to_numpy(attachments), gradient_numpy

        else:
            if mode in ['complete', 'class2']:
                return self._to_numpy(attachment), self._to_numpy(regularity)
            elif mode == 'model':
                return self._to_numpy(attachments)

    def compute_sufficient_statistics(self, dataset, population_RER, individual_RER, residuals=None):
        """
        Compute the model sufficient statistics.

        :param residuals: Optional list (one entry per subject) of per-object residual tensors, as returned by
            _compute_residuals. Recomputed from the current fixed effects when None.
        :return: Dictionary with 'S1', the sum over subjects of the momenta outer products, and 'S2', the sum over
            subjects of the residuals, one entry per template object.
        """
        if residuals is None:
            # Initialize: conversion from numpy to torch ---------------------------------------------------------------
            template_data, template_points, control_points = self._fixed_effects_to_torch_tensors(False)
            momenta = self._individual_RER_to_torch_tensors(individual_RER, False)

            # Compute residuals: kept per object, exactly as in compute_log_likelihood ---------------------------------
            residuals = self._compute_residuals(dataset, template_data, template_points, control_points, momenta)

        # Compute sufficient statistics --------------------------------------------------------------------------------
        sufficient_statistics = {}

        # Empirical momenta covariance.
        momenta = individual_RER['momenta']
        sufficient_statistics['S1'] = np.zeros((momenta[0].size, momenta[0].size))
        for i in range(dataset.number_of_subjects):
            momenta_i = momenta[i].reshape(-1, 1)
            sufficient_statistics['S1'] += np.dot(momenta_i, momenta_i.transpose())

        # Empirical residuals, for each object: residuals[i] holds the per-object residuals of subject i, so they are
        # accumulated over subjects (not indexed by object).
        sufficient_statistics['S2'] = np.zeros((self.number_of_objects,))
        for i in range(dataset.number_of_subjects):
            sufficient_statistics['S2'] += residuals[i].detach().cpu().numpy()

        # Finalization -------------------------------------------------------------------------------------------------
        return sufficient_statistics

    def update_fixed_effects(self, dataset, sufficient_statistics):
        """
        Updates the class 1 fixed effects based on the sufficient statistics, maximizing the likelihood.

        Both updates have the same structure: (empirical statistic + prior dof * prior scale) / (count + prior dof).
        """
        # Covariance of the momenta update. The whole numerator is divided, not only the prior term.
        prior_scale_matrix = self.priors['covariance_momenta'].scale_matrix
        prior_dof = self.priors['covariance_momenta'].degrees_of_freedom
        covariance_momenta = (sufficient_statistics['S1'] + prior_dof * np.transpose(prior_scale_matrix)) \
                             / (dataset.number_of_subjects + prior_dof)
        self.set_covariance_momenta(covariance_momenta)

        # Variance of the residual noise update, one value per object.
        noise_variance = np.zeros((self.number_of_objects,))
        prior_scale_scalars = self.priors['noise_variance'].scale_scalars
        prior_dofs = self.priors['noise_variance'].degrees_of_freedom
        for k in range(self.number_of_objects):
            noise_variance[k] = (sufficient_statistics['S2'][k] + prior_scale_scalars[k] * prior_dofs[k]) \
                                / float(dataset.number_of_subjects * self.objects_noise_dimension[k] + prior_dofs[k])
        self.set_noise_variance(noise_variance)

    def initialize_template_attributes(self, template_specifications):
        """
        Build the template from its specifications: sets the template object list, objects_name,
        objects_name_extension and multi_object_attachment, then computes objects_noise_dimension.
        The noise variance returned by create_template_metadata is not used here.
        """

        t_list, t_name, t_name_extension, t_noise_variance, t_multi_object_attachment = \
            create_template_metadata(template_specifications)

        self.template.object_list = t_list
        self.objects_name = t_name
        self.objects_name_extension = t_name_extension
        self.multi_object_attachment = t_multi_object_attachment

        self.template.update()
        self.objects_noise_dimension = compute_noise_dimension(self.template, self.multi_object_attachment)

    ####################################################################################################################
    ### Private methods:
    ####################################################################################################################

    def _compute_attachment(self, residuals):
        """
        Fully torch. Sum of the per-subject attachments.
        """
        return torch.sum(self._compute_individual_attachments(residuals))

    def _compute_individual_attachments(self, residuals):
        """
        Fully torch. For each subject i: -0.5 * sum_k residuals[i][k] / noise_variance[k].
        """
        number_of_subjects = len(residuals)
        attachments = Variable(torch.zeros((number_of_subjects,)).type(Settings().tensor_scalar_type),
                               requires_grad=False)
        # The noise variance does not depend on the subject: convert it once.
        noise_variance = Variable(torch.from_numpy(self.fixed_effects['noise_variance']).type(
            Settings().tensor_scalar_type), requires_grad=False)
        for i in range(number_of_subjects):
            attachments[i] = -0.5 * torch.sum(residuals[i] / noise_variance)
        return attachments

    def _compute_random_effects_regularity(self, momenta):
        """
        Fully torch. Log-likelihood of the momenta random effect, plus the log-normalization of the residual noise.
        """
        number_of_subjects = momenta.shape[0]
        regularity = 0.0

        # Momenta random effect.
        for i in range(number_of_subjects):
            regularity += self.individual_random_effects['momenta'].compute_log_likelihood_torch(momenta[i])

        # Noise random effect.
        for k in range(self.number_of_objects):
            regularity -= 0.5 * self.objects_noise_dimension[k] * number_of_subjects \
                          * math.log(self.fixed_effects['noise_variance'][k])

        return regularity

    def _compute_class1_priors_regularity(self):
        """
        Prior terms of the class 1 fixed effects, i.e. those for which we know a closed-form update. No derivative
        wrt those fixed effects will therefore be necessary.
        """
        regularity = 0.0

        # Covariance momenta prior.
        regularity += self.priors['covariance_momenta'].compute_log_likelihood(
            self.fixed_effects['covariance_momenta_inverse'])

        # Noise variance prior.
        regularity += self.priors['noise_variance'].compute_log_likelihood(self.fixed_effects['noise_variance'])

        return regularity

    def _compute_class2_priors_regularity(self, template_data, control_points):
        """
        Prior terms of the class 2 fixed effects, i.e. those for which we do not know a closed-form update.
        Derivative wrt those fixed effects will therefore be necessary. Currently always 0.0.
        """
        regularity = 0.0

        # Prior on template_data fixed effects (if not frozen). None implemented yet TODO.
        if not self.freeze_template:
            regularity += 0.0

        # Prior on control_points fixed effects (if not frozen). None implemented yet TODO.
        if not self.freeze_control_points:
            regularity += 0.0

        return regularity

    def _compute_residuals(self, dataset, template_data, template_points, control_points, momenta):
        """
        Core part of the ComputeLogLikelihood methods. Fully torch.

        Shoots the template towards the first observation of each subject and returns, per subject, the distances
        computed by the multi-object attachment (one entry per template object).
        """

        # Initialize: cross-sectional dataset, only the first visit of each subject is used ----------------------------
        targets = [target[0] for target in dataset.deformable_objects]

        # Deform -------------------------------------------------------------------------------------------------------
        residuals = []

        self.exponential.set_initial_template_points(template_points)
        self.exponential.set_initial_control_points(control_points)

        for i, target in enumerate(targets):
            self.exponential.set_initial_momenta(momenta[i])
            self.exponential.update()
            deformed_points = self.exponential.get_template_points()
            deformed_data = self.template.get_deformed_data(deformed_points, template_data)
            residuals.append(self.multi_object_attachment.compute_distances(deformed_data, self.template, target))

        return residuals

    def _initialize_control_points(self):
        """
        Initialize the control points fixed effect: a regular grid over the bounding box, or, in dense mode, the
        template landmark points themselves.
        """
        if not Settings().dense_mode:
            control_points = create_regular_grid_of_points(self.bounding_box, self.initial_cp_spacing)
        else:
            # get_points() returns a dictionary; the control points are the landmark points array.
            control_points = self.template.get_points()['landmark_points']

        self.set_control_points(control_points)
        self.number_of_control_points = control_points.shape[0]
        logger.info('Set of ' + str(self.number_of_control_points) + ' control points defined.')

    def _initialize_momenta(self):
        """
        Initialize the momenta random effect: zero mean, and covariance derived from the kernel.
        """
        self.individual_random_effects['momenta'].mean = \
            np.zeros((self.number_of_control_points * Settings().dimension,))
        self._initialize_covariance()  # Initialize the prior and the momenta random effect.

    def _initialize_covariance(self):
        """
        Initialize the scale matrix of the inverse Wishart prior, as well as the covariance matrix of the normal
        random effect.

        The RKHS matrix holds, for each pair of control points (i, j), the block K(cp_i, cp_j) * Identity(dimension)
        where K is a Gaussian kernel. It is used as the momenta covariance inverse; its inverse is the prior scale.
        """
        assert self.exponential.kernel.kernel_width is not None
        dimension = Settings().dimension  # Shorthand.
        kernel_width = self.exponential.kernel.kernel_width
        control_points = np.asarray(self.fixed_effects['control_points'], dtype=np.float64)

        # Pairwise squared distances, then Gaussian kernel, vectorized instead of a double Python loop.
        differences = control_points[:, np.newaxis, :] - control_points[np.newaxis, :, :]
        squared_distances = np.sum(differences ** 2, axis=2)
        kernel_matrix = np.exp(-squared_distances / (kernel_width ** 2))

        # Entry [dimension * i + d, dimension * j + d] = K(cp_i, cp_j); every other entry is zero.
        rkhs_matrix = np.kron(kernel_matrix, np.eye(dimension))

        self.priors['covariance_momenta'].scale_matrix = np.linalg.inv(rkhs_matrix)
        self.set_covariance_momenta_inverse(rkhs_matrix)

    def _initialize_noise_variance(self):
        """Initialize the noise variance fixed effect to the prior scale scalars."""
        self.set_noise_variance(np.asarray(self.priors['noise_variance'].scale_scalars))

    def _initialize_bounding_box(self):
        """
        Initialize the bounding box. which tightly encloses all template objects and the atlas control points.
        Relevant when the control points are given by the user.
        NOTE(review): self.bounding_box is the template's own bounding box array (assigned in update), so it is
        enlarged in place.
        """

        assert (self.number_of_control_points > 0)

        dimension = Settings().dimension
        control_points = self.get_control_points()

        for k in range(self.number_of_control_points):
            for d in range(dimension):
                if control_points[k, d] < self.bounding_box[d, 0]:
                    self.bounding_box[d, 0] = control_points[k, d]
                elif control_points[k, d] > self.bounding_box[d, 1]:
                    self.bounding_box[d, 1] = control_points[k, d]

    ####################################################################################################################
    ### Private utility methods:
    ####################################################################################################################

    def _fixed_effects_to_torch_tensors(self, with_grad):
        """
        Convert the input fixed_effects into torch tensors.

        :return: (template_data, template_points, control_points). The first two are dictionaries of tensors,
            control_points is a tensor. In dense mode, control_points is the landmark points tensor itself.
        """
        # Template data.
        template_data = {
            key: Variable(torch.from_numpy(value).type(Settings().tensor_scalar_type),
                          requires_grad=(not self.freeze_template and with_grad))
            for key, value in self.fixed_effects['template_data'].items()
        }

        # Template points.
        template_points = {
            key: Variable(torch.from_numpy(value).type(Settings().tensor_scalar_type),
                          requires_grad=(not self.freeze_template and with_grad))
            for key, value in self.template.get_points().items()
        }

        # Control points.
        if Settings().dense_mode:
            # The control points coincide with the template landmark points (a tensor, not the data dictionary).
            control_points = template_points['landmark_points']
        else:
            control_points = Variable(
                torch.from_numpy(self.fixed_effects['control_points']).type(Settings().tensor_scalar_type),
                requires_grad=((not self.freeze_control_points) and with_grad))

        return template_data, template_points, control_points

    def _individual_RER_to_torch_tensors(self, individual_RER, with_grad):
        """
        Convert the input individual_RER into torch tensors.
        """
        momenta = individual_RER['momenta']
        # Convert the type BEFORE flagging requires_grad: when .type() actually converts, it returns a new non-leaf
        # tensor whose .grad would stay None after backward().
        return torch.from_numpy(momenta).type(Settings().tensor_scalar_type).requires_grad_(with_grad)

    @staticmethod
    def _to_numpy(value):
        """Detach a torch tensor into a numpy array; plain Python/numpy scalars are wrapped with np.asarray."""
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    ####################################################################################################################
    ### Printing and writing methods:
    ####################################################################################################################

    def print(self, individual_RER):
        pass

    def write(self, dataset, population_RER, individual_RER, update_fixed_effects=True, write_residuals=True):
        """
        Write the reconstructions, optionally update the class 1 fixed effects and write the residuals, then write
        the model parameters.
        """

        # Write the model predictions, and compute the residuals at the same time.
        residuals = self._write_model_predictions(dataset, individual_RER,
                                                  compute_residuals=(update_fixed_effects or write_residuals))

        # Optionally update the fixed effects.
        if update_fixed_effects:
            sufficient_statistics = self.compute_sufficient_statistics(
                dataset, population_RER, individual_RER, residuals=residuals)
            self.update_fixed_effects(dataset, sufficient_statistics)

        # Write residuals.
        if write_residuals:
            residuals_list = [[residuals_i_k.detach().cpu().numpy() for residuals_i_k in residuals_i]
                              for residuals_i in residuals]
            write_2D_list(residuals_list, self.name + "__EstimatedParameters__Residuals.txt")

        # Write the model parameters.
        self._write_model_parameters(individual_RER)

    def _write_model_predictions(self, dataset, individual_RER, compute_residuals=True):
        """
        Shoot the template towards every subject, write each reconstruction, and return the per-subject residuals
        (an empty list when compute_residuals is False).
        """

        # Initialize.
        template_data, template_points, control_points = self._fixed_effects_to_torch_tensors(False)
        momenta = self._individual_RER_to_torch_tensors(individual_RER, False)

        # Deform, write reconstructions and compute residuals.
        self.exponential.set_initial_template_points(template_points)
        self.exponential.set_initial_control_points(control_points)

        residuals = []  # List of torch 1D tensors. Individuals, objects.
        for i, subject_id in enumerate(dataset.subject_ids):
            self.exponential.set_initial_momenta(momenta[i])
            self.exponential.update()

            deformed_points = self.exponential.get_template_points()
            deformed_data = self.template.get_deformed_data(deformed_points, template_data)

            if compute_residuals:
                residuals.append(self.multi_object_attachment.compute_distances(
                    deformed_data, self.template, dataset.deformable_objects[i][0]))

            names = [self.name + '__Reconstruction__' + object_name + '__subject_' + str(subject_id) + object_extension
                     for object_name, object_extension in zip(self.objects_name, self.objects_name_extension)]
            self.template.write(names, {key: value.data.cpu().numpy() for key, value in deformed_data.items()})

        return residuals

    def _write_model_parameters(self, individual_RER):
        """Write the template, control points, momenta, momenta covariance inverse and noise standard deviations."""
        # Template.
        template_names = [self.name + "__EstimatedParameters__Template_" + object_name + object_extension
                          for object_name, object_extension in zip(self.objects_name, self.objects_name_extension)]
        self.template.write(template_names)

        # Control points.
        write_2D_array(self.get_control_points(), self.name + "__EstimatedParameters__ControlPoints.txt")

        # Momenta.
        write_3D_array(individual_RER['momenta'], self.name + "__EstimatedParameters__Momenta.txt")

        # Momenta covariance.
        write_2D_array(self.get_covariance_momenta_inverse(),
                       self.name + "__EstimatedParameters__CovarianceMomentaInverse.txt")

        # Noise variance, written as standard deviations.
        write_2D_array(np.sqrt(self.get_noise_variance()), self.name + "__EstimatedParameters__NoiseStd.txt")
# Code example #14
# 0
def _exp_parallelize(control_points, initial_momenta, projected_momenta,
                     xml_parameters):
    """
    Shoot a geodesic, parallel-transport momenta along it, and write the results.

    Steps performed by the code below:
      1. Build the template (a DeformableMultiObject) from
         ``xml_parameters.template_specifications`` via create_template_metadata.
      2. Shoot a geodesic from ``control_points`` / ``initial_momenta`` between
         ``xml_parameters.tmin`` and ``xml_parameters.tmax``. Its t0 is
         ``xml_parameters.t0``, or tmin when that is None. Write the geodesic
         flow under the "Regression" prefix.
      3. Parallel-transport ``projected_momenta`` along that geodesic.
      4. For each time point of the geodesic:
         * write its control points, momenta and transported momenta to
           text files named with the time-point index and age;
         * shoot an exponential from the geodesic's template data at that
           point, using the transported momenta;
         * write the resulting deformed template ("parallel curve") objects.

    Args:
        control_points: control points at t0 given to the geodesic.
            NOTE(review): presumably torch Variables, like the template data
            built below; confirm against the caller.
        initial_momenta: momenta at t0 that define the geodesic.
        projected_momenta: momenta transported along the geodesic.
        xml_parameters: parsed parameters. This function reads
            template_specifications, concentration_of_time_points,
            deformation_kernel_type, deformation_kernel_width, use_rk2,
            tmin, tmax, t0 and number_of_time_points.

    Raises:
        AssertionError: if ``tmin`` equals -inf or ``tmax`` equals +inf
            (the assertion messages ask the user to specify them).
    """
    objects_list, objects_name, objects_name_extension, _, _ = create_template_metadata(
        xml_parameters.template_specifications)
    template = DeformableMultiObject()
    template.object_list = objects_list
    template.update()

    # Template points as a torch Variable of the configured scalar type,
    # used as the geodesic's initial template data.
    template_data = template.get_points()
    template_data_torch = Variable(
        torch.from_numpy(template_data).type(Settings().tensor_scalar_type))

    geodesic = Geodesic()
    geodesic.concentration_of_time_points = xml_parameters.concentration_of_time_points
    geodesic.set_kernel(
        kernel_factory.factory(xml_parameters.deformation_kernel_type,
                               xml_parameters.deformation_kernel_width))
    geodesic.set_use_rk2(xml_parameters.use_rk2)

    # Those are mandatory parameters.
    # NOTE(review): `assert` statements are removed under `python -O`, so these
    # checks would silently disappear; raising an exception would be sturdier.
    assert xml_parameters.tmin != -float(
        "inf"), "Please specify a minimum time for the geodesic trajectory"
    assert xml_parameters.tmax != float(
        "inf"), "Please specify a maximum time for the geodesic trajectory"

    geodesic.tmin = xml_parameters.tmin
    geodesic.tmax = xml_parameters.tmax
    # Without an explicit t0, the geodesic starts at tmin.
    if xml_parameters.t0 is None:
        geodesic.t0 = geodesic.tmin
    else:
        geodesic.t0 = xml_parameters.t0

    geodesic.set_momenta_t0(initial_momenta)
    geodesic.set_control_points_t0(control_points)
    geodesic.set_template_data_t0(template_data_torch)
    geodesic.update()

    # We write the flow of the geodesic

    geodesic.write("Regression", objects_name, objects_name_extension,
                   template)

    # Now we transport!
    parallel_transport_trajectory = geodesic.parallel_transport(
        projected_momenta)

    # Getting trajectory characteristics (one entry per geodesic time point;
    # these lists are zipped together below):
    times = geodesic._get_times()
    control_points_traj = geodesic._get_control_points_trajectory()
    momenta_traj = geodesic._get_momenta_trajectory()
    template_data_traj = geodesic._get_template_data_trajectory()

    # A single Exponential is reused for every time point; only its initial
    # conditions change inside the loop.
    exponential = Exponential()
    exponential.number_of_time_points = xml_parameters.number_of_time_points
    exponential.set_kernel(
        kernel_factory.factory(xml_parameters.deformation_kernel_type,
                               xml_parameters.deformation_kernel_width))
    exponential.set_use_rk2(xml_parameters.use_rk2)

    # We save this trajectory, and the corresponding shape trajectory
    for i, (time, cp, mom, transported_mom, td) in enumerate(
            zip(times, control_points_traj, momenta_traj,
                parallel_transport_trajectory, template_data_traj)):
        # Writing the momenta/cps
        # NOTE(review): `.data.numpy()` only works on CPU tensors; other code in
        # this file uses `.data.cpu().numpy()`. Confirm these never live on a GPU.
        write_2D_array(
            cp.data.numpy(),
            "control_Points_tp_" + str(i) + "__age_" + str(time) + ".txt")
        write_3D_array(mom.data.numpy(),
                       "momenta_tp_" + str(i) + "__age_" + str(time) + ".txt")
        write_3D_array(
            transported_mom.data.numpy(),
            "transported_momenta_tp_" + str(i) + "__age_" + str(time) + ".txt")

        # Shooting from the geodesic:
        exponential.set_initial_template_data(td)
        exponential.set_initial_control_points(cp)
        exponential.set_initial_momenta(transported_mom)
        exponential.update()

        # Uncomment for massive writing, useful for debugging.
        # dir = "exp_"+str(i)+"_"+str(time)
        # if not(os.path.isdir(os.path.join(Settings().output_dir, dir))):
        #     os.mkdir(os.path.join(Settings().output_dir, dir))
        # exponential.write_flow([os.path.join(dir, elt) for elt in objects_name],
        #                        objects_name_extension,
        #                        template)
        # exponential.write_control_points_and_momenta_flow(os.path.join(dir, "cp_and_mom"))

        # Overwrite the template's points with the shot shape, then write it.
        parallel_td = exponential.get_template_data()
        template.set_points(parallel_td)
        # NOTE(review): an extra "_" is inserted before the extension here, while
        # other writers in this file append the extension directly. If extensions
        # start with ".", this yields names ending in "_.ext"; confirm intended.
        names = [
            objects_name[k] + "_parallel_curve_tp_" + str(i) + "__age_" +
            str(time) + "_" + objects_name_extension[k]
            for k in range(len(objects_name))
        ]
        template.write(names)