Example #1
def initialize_spatiotemporal_reference_frame(model,
                                              xml_parameters,
                                              dataset,
                                              observation_type='image'):
    """
    Initialize everything which is relative to the geodesic its parameters.
    """
    assert xml_parameters.dimension is not None, "Provide a dimension for the longitudinal metric learning atlas."

    exponential_factory = ExponentialFactory()
    if xml_parameters.exponential_type is None:
        msg = "Defaulting exponential type to parametric"
        warnings.warn(msg)
        xml_parameters.exponential_type = 'parametric'

    print("Initializing exponential type to",
          xml_parameters.exponential_type)
    exponential_factory.set_manifold_type(xml_parameters.exponential_type)

    # Reading parameter file, if there is one:
    metric_parameters = None
    if xml_parameters.metric_parameters_file is not None:
        print("Loading metric parameters from file",
              xml_parameters.metric_parameters_file)
        metric_parameters = np.loadtxt(xml_parameters.metric_parameters_file)

    # Initial metric parameters.
    if exponential_factory.manifold_type == 'parametric':
        metric_parameters = _initialize_parametric_exponential(
            model, xml_parameters, dataset, exponential_factory,
            metric_parameters)

    elif exponential_factory.manifold_type == 'deep':
        manifold_parameters = {
            'latent_space_dimension': xml_parameters.latent_space_dimension
        }
        exponential_factory.set_parameters(manifold_parameters)

    elif exponential_factory.manifold_type == 'logistic':
        # No initial parameter to set: simply freeze the model's metric parameters.
        model.is_frozen['metric_parameters'] = True

    model.spatiotemporal_reference_frame = GenericSpatiotemporalReferenceFrame(
        exponential_factory)
    model.spatiotemporal_reference_frame.set_concentration_of_time_points(
        xml_parameters.concentration_of_time_points)
    model.spatiotemporal_reference_frame.set_number_of_time_points(
        xml_parameters.number_of_time_points)
    model.parametric_metric = (xml_parameters.exponential_type == 'parametric')

    if xml_parameters.exponential_type == 'deep':
        model.deep_metric_learning = True
        model.latent_space_dimension = xml_parameters.latent_space_dimension
        model.initialize_deep_metric_learning()
        model.set_metric_parameters(metric_parameters)

    if xml_parameters.exponential_type == 'parametric':
        model.is_frozen[
            'metric_parameters'] = xml_parameters.freeze_metric_parameters
        model.set_metric_parameters(metric_parameters)

    if Settings().dimension == 1:
        print(
            "I am setting the no_parallel_transport flag to True because the dimension is 1"
        )
        model.no_parallel_transport = True
        model.spatiotemporal_reference_frame.no_parallel_transport = True
        model.number_of_sources = 0

    elif xml_parameters.number_of_sources is None or xml_parameters.number_of_sources == 0:
        print(
            "I am setting the no_parallel_transport flag to True because the number of sources is 0."
        )
        model.no_parallel_transport = True
        model.spatiotemporal_reference_frame.no_parallel_transport = True
        model.number_of_sources = 0

    else:
        print("I am setting the no_parallel_transport flag to False.")
        model.no_parallel_transport = False
        model.spatiotemporal_reference_frame.no_parallel_transport = False
        model.number_of_sources = xml_parameters.number_of_sources
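
# A hand-written summary of the dispatch above (not part of the original
# source): each exponential type implies a different setup.
#   'parametric' -> metric_parameters are initialized (possibly from file) and may be frozen;
#   'deep'       -> latent_space_dimension is required and deep metric learning is enabled;
#   'logistic'   -> nothing to initialize; the metric parameters are frozen.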
Example #2
    def update(self):
        """
        Runs the gradient ascent algorithm and updates the statistical model.
        """

        # Initialisation -----------------------------------------------------------------------------------------------
        # First case: we use the initialization stored in the state file
        if Settings().load_state:
            self.current_parameters, self.current_iteration = self._load_state_file()
            self._set_parameters(self.current_parameters)  # Propagate the parameter values.
            logger.info("State file loaded, it was at iteration %d.",
                        self.current_iteration)

        # Second case: we use the native initialization of the model.
        else:
            self.current_parameters = self._get_parameters()
            self.current_iteration = 0

        # Uncomment to check the model gradient.
        # WARNING: also comment out the update_fixed_effects method of the model!
        # print("Checking the model gradient:")
        # self._check_model_gradient()

        self.current_attachment, self.current_regularity, gradient = self._evaluate_model_fit(
            self.current_parameters, with_grad=True)
        self.current_log_likelihood = self.current_attachment + self.current_regularity
        self.print()

        initial_log_likelihood = self.current_log_likelihood
        last_log_likelihood = initial_log_likelihood

        nb_params = len(gradient)
        self.step = self._initialize_step_size(gradient)

        # Main loop ----------------------------------------------------------------------------------------------------
        while self.current_iteration < self.max_iterations:
            self.current_iteration += 1

            # Line search ----------------------------------------------------------------------------------------------
            found_min = False
            for li in range(self.max_line_search_iterations):

                # Print step size --------------------------------------------------------------------------------------
                if not (self.current_iteration % self.print_every_n_iters):
                    logger.debug('Step size and gradient squared norm: ')
                    for key in gradient.keys():
                        logger.debug(
                            '\t\t%.3E   and   %.3E \t[ %s ]' %
                            (Decimal(str(self.step[key])),
                             Decimal(str(np.sum(gradient[key]**2))), key))

                # Try a simple gradient ascent step --------------------------------------------------------------------
                new_parameters = self._gradient_ascent_step(
                    self.current_parameters, gradient, self.step)
                new_attachment, new_regularity = self._evaluate_model_fit(
                    new_parameters)

                q = new_attachment + new_regularity - last_log_likelihood
                if q > 0:
                    found_min = True
                    self.step = {
                        key: value * self.line_search_expand
                        for key, value in self.step.items()
                    }
                    break

                # Adapting the step sizes ------------------------------------------------------------------------------
                self.step = {
                    key: value * self.line_search_shrink
                    for key, value in self.step.items()
                }
                if nb_params > 1:
                    new_parameters_prop = {}
                    new_attachment_prop = {}
                    new_regularity_prop = {}
                    q_prop = {}

                    for key in self.step.keys():
                        local_step = self.step.copy()
                        local_step[key] /= self.line_search_shrink

                        new_parameters_prop[key] = self._gradient_ascent_step(
                            self.current_parameters, gradient, local_step)
                        new_attachment_prop[key], new_regularity_prop[key] = \
                            self._evaluate_model_fit(new_parameters_prop[key])

                        q_prop[key] = (new_attachment_prop[key] +
                                       new_regularity_prop[key] -
                                       last_log_likelihood)

                    key_max = max(q_prop.keys(), key=(lambda key: q_prop[key]))
                    if q_prop[key_max] > 0:
                        new_attachment = new_attachment_prop[key_max]
                        new_regularity = new_regularity_prop[key_max]
                        new_parameters = new_parameters_prop[key_max]
                        self.step[key_max] /= self.line_search_shrink
                        found_min = True
                        break

            # End of line search ---------------------------------------------------------------------------------------
            if not found_min:
                self._set_parameters(self.current_parameters)
                logger.info('Number of line search loops exceeded. Stopping.')
                break

            self.current_attachment = new_attachment
            self.current_regularity = new_regularity
            self.current_log_likelihood = new_attachment + new_regularity
            self.current_parameters = new_parameters
            self._set_parameters(self.current_parameters)

            # Test the stopping criterion ------------------------------------------------------------------------------
            current_log_likelihood = self.current_log_likelihood
            delta_f_current = last_log_likelihood - current_log_likelihood
            delta_f_initial = initial_log_likelihood - current_log_likelihood

            if math.fabs(delta_f_current) < self.convergence_tolerance * math.fabs(delta_f_initial):
                logger.info(
                    'Tolerance threshold met. Stopping the optimization process.'
                )
                break

            # Printing and writing -------------------------------------------------------------------------------------
            if not self.current_iteration % self.print_every_n_iters:
                self.print()
            if not self.current_iteration % self.save_every_n_iters:
                self.write()

            # Prepare next iteration -----------------------------------------------------------------------------------
            last_log_likelihood = current_log_likelihood
            if self.current_iteration != self.max_iterations:
                gradient = self._evaluate_model_fit(self.current_parameters,
                                                    with_grad=True)[2]

            # Save the state.
            if not self.current_iteration % self.save_every_n_iters:
                self._dump_state_file()
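
# A self-contained sketch of the per-parameter line-search trick used above
# (toy objective, hypothetical names): after a failed joint step, every step
# size is shrunk, then for each parameter we probe a candidate where only that
# parameter keeps its previous (un-shrunk) step size, and the best strictly
# improving candidate wins.
import numpy as np

def objective(params):
    # Toy concave objective to *maximize*.
    return -sum(float(np.sum(value ** 2)) for value in params.values())

def ascent_step(params, gradient, step):
    return {key: value + step[key] * gradient[key] for key, value in params.items()}

params = {'a': np.array([1.0]), 'b': np.array([-2.0])}
step = {'a': 1.0, 'b': 1.0}
shrink = 0.5

gradient = {key: -2.0 * value for key, value in params.items()}  # Gradient of the toy objective.
last_value = objective(params)

step = {key: value * shrink for key, value in step.items()}  # Joint shrink.
candidates = {}
for key in step:
    local_step = dict(step)
    local_step[key] /= shrink  # Restore this parameter's previous step size.
    candidates[key] = objective(ascent_step(params, gradient, local_step))

best = max(candidates, key=candidates.get)
if candidates[best] > last_value:
    step[best] /= shrink  # Keep the larger step for the winning parameter.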
Example #3
 def _load_state_file(self):
     """
     Loads Settings().state_file and returns what is necessary to restart the scipy optimization.
     """
     with open(Settings().state_file, 'rb') as f:
         d = pickle.load(f)
     return d['parameters'], d['current_iteration'], d['parameters_shape'], d['parameters_order']
Example #4
 def _load_state_file(self):
     with open(Settings().state_file, 'rb') as f:
         d = pickle.load(f)
     return d['current_parameters'], d['current_iteration']
Example #5
 def _dump_state_file(self):
     d = {
         'current_parameters': self.current_parameters,
         'current_iteration': self.current_iteration
     }
     with open(Settings().state_file, 'wb') as f:
         pickle.dump(d, f)
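
# A minimal round-trip sketch of the state-file mechanism above (toy values,
# hypothetical path): _dump_state_file pickles a dict, and _load_state_file
# reads the same keys back.
import pickle

state = {'current_parameters': {'momenta': [0.1, -0.2]}, 'current_iteration': 42}
with open('/tmp/deformetrica_state.p', 'wb') as f:
    pickle.dump(state, f)
with open('/tmp/deformetrica_state.p', 'rb') as f:
    restored = pickle.load(f)
assert restored['current_iteration'] == 42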
Example #6
    def extend(self, number_of_additional_time_points):

        # Special case of the exponential reduced to a single point.
        if self.number_of_time_points == 1:
            self.number_of_time_points += number_of_additional_time_points
            self.update()
            return

        # Extended shoot.
        dt = 1.0 / float(self.number_of_time_points - 1)  # Same time-step.
        for i in range(number_of_additional_time_points):
            if self.use_rk2:
                new_cp, new_mom = self._rk2_step(self.control_points_t[-1], self.momenta_t[-1], dt, return_mom=True)
            else:
                new_cp, new_mom = self._euler_step(self.control_points_t[-1], self.momenta_t[-1], dt)

            self.control_points_t.append(new_cp)
            self.momenta_t.append(new_mom)

        # Scaling of the new length.
        length_ratio = float(self.number_of_time_points + number_of_additional_time_points - 1) \
                       / float(self.number_of_time_points - 1)
        self.number_of_time_points += number_of_additional_time_points
        self.initial_momenta = self.initial_momenta * length_ratio
        self.momenta_t = [elt * length_ratio for elt in self.momenta_t]
        self.norm_squared = self.norm_squared * length_ratio ** 2

        # Extended flow.
        # Special case of the dense mode.
        if Settings().dense_mode:
            assert 'image_points' not in self.initial_template_points.keys(), 'Dense mode not allowed with image data.'
            self.template_points_t['landmark_points'] = self.control_points_t
            return

        # Standard case.
        # Flow landmark points.
        if 'landmark_points' in self.initial_template_points.keys():
            for ii in range(number_of_additional_time_points):
                i = len(self.template_points_t['landmark_points']) - 1
                d_pos = self.kernel.convolve(
                    self.template_points_t['landmark_points'][i], self.control_points_t[i], self.momenta_t[i])
                self.template_points_t['landmark_points'].append(
                    self.template_points_t['landmark_points'][i] + dt * d_pos)

                if self.use_rk2:
                    # Improved Euler (Heun's method), saving one computation of the convolve gradient.
                    self.template_points_t['landmark_points'][i + 1] = \
                        self.template_points_t['landmark_points'][i] + dt / 2 * (self.kernel.convolve(
                            self.template_points_t['landmark_points'][i + 1],
                            self.control_points_t[i + 1], self.momenta_t[i + 1]) + d_pos)

        # Flow image points.
        if 'image_points' in self.initial_template_points.keys():
            dimension = Settings().dimension
            image_shape = self.initial_template_points['image_points'].size()

            for ii in range(number_of_additional_time_points):
                i = len(self.template_points_t['image_points']) - 1
                vf = self.kernel.convolve(self.initial_template_points['image_points'].contiguous().view(-1, dimension),
                                          self.control_points_t[i], self.momenta_t[i]).view(image_shape)
                dY = self._compute_image_explicit_euler_step_at_order_1(self.template_points_t['image_points'][i], vf)
                self.template_points_t['image_points'].append(self.template_points_t['image_points'][i] - dY)

            if self.use_rk2:
                msg = 'RK2 not implemented to flow image points.'
                warnings.warn(msg)
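
# A note on the rescaling above (reasoning spelled out here, not taken from
# the original source): the trajectory is sampled with a fixed time-step
# dt = 1 / (T - 1), where T = number_of_time_points. Appending n extra steps
# extends the covered length by the ratio L = (T + n - 1) / (T - 1);
# reparametrizing the extended geodesic back to the unit interval multiplies
# the velocities, hence the momenta, by L, and the squared norm ||p||^2 by
# L^2 -- which is exactly what the scaling block implements.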
Example #7
    def __init__(self, kernel_width=None):
        self.kernel_type = 'keops'
        super().__init__(kernel_width)

        dim = str(Settings().dimension)

        self.gaussian_convolve = generic_sum(
            "Exp(-G*SqDist(X,Y)) * P",
            "O = Vx(" + dim + ")", "G = Pm(1)",
            "X = Vx(" + dim + ")",
            "Y = Vy(" + dim + ")",
            "P = Vy(" + dim + ")")

        self.varifold_convolve = generic_sum(
            "Exp(-(WeightedSqDist(G, X, Y))) * Pow((Nx, Ny), 2) * P",
            "O = Vx(1)", "G = Pm(1)",
            "X = Vx(" + dim + ")",
            "Y = Vy(" + dim + ")",
            "Nx = Vx(" + dim + ")",
            "Ny = Vy(" + dim + ")", "P = Vy(1)")

        self.gaussian_convolve_gradient_x = generic_sum(
            "(Px, Py) * Exp(-G*SqDist(X,Y)) * (X-Y)",
            "O = Vx(" + dim + ")", "G = Pm(1)",
            "X = Vx(" + dim + ")",
            "Y = Vy(" + dim + ")",
            "Px = Vx(" + dim + ")",
            "Py = Vy(" + dim + ")")
Example #8
    def _read_model_xml(self, model_xml_path):

        def _path(text):
            # Resolves a path relative to the directory containing the model xml.
            return os.path.normpath(
                os.path.join(os.path.dirname(model_xml_path), text))

        model_xml_level0 = et.parse(model_xml_path).getroot()

        for model_xml_level1 in model_xml_level0:
            tag_level1 = model_xml_level1.tag.lower()

            if tag_level1 == 'model-type':
                self.model_type = model_xml_level1.text.lower()

            elif tag_level1 == 'dimension':
                self.dimension = int(model_xml_level1.text)
                Settings().dimension = self.dimension

            elif tag_level1 == 'initial-control-points':
                self.initial_control_points = _path(model_xml_level1.text)

            elif tag_level1 == 'initial-momenta':
                self.initial_momenta = _path(model_xml_level1.text)

            elif tag_level1 == 'initial-modulation-matrix':
                self.initial_modulation_matrix = _path(model_xml_level1.text)

            elif tag_level1 == 'initial-time-shift-std':
                self.initial_time_shift_variance = float(model_xml_level1.text) ** 2

            elif tag_level1 == 'initial-log-acceleration-std':
                self.initial_log_acceleration_variance = float(model_xml_level1.text) ** 2

            elif tag_level1 == 'initial-log-acceleration-mean':
                self.initial_log_acceleration_mean = float(model_xml_level1.text)

            elif tag_level1 == 'initial-onset-ages':
                self.initial_onset_ages = _path(model_xml_level1.text)

            elif tag_level1 == 'initial-log-accelerations':
                self.initial_log_accelerations = _path(model_xml_level1.text)

            elif tag_level1 == 'initial-sources':
                self.initial_sources = _path(model_xml_level1.text)

            elif tag_level1 == 'initial-sources-mean':
                self.initial_sources_mean = model_xml_level1.text

            elif tag_level1 == 'initial-sources-std':
                self.initial_sources_std = model_xml_level1.text

            elif tag_level1 == 'initial-momenta-to-transport':
                self.initial_momenta_to_transport = _path(model_xml_level1.text)

            elif tag_level1 == 'initial-control-points-to-transport':
                self.initial_control_points_to_transport = _path(model_xml_level1.text)

            elif tag_level1 == 'initial-noise-std':
                self.initial_noise_variance = float(model_xml_level1.text) ** 2

            elif tag_level1 == 'latent-space-dimension':
                self.latent_space_dimension = int(model_xml_level1.text)

            elif tag_level1 == 'template':
                for model_xml_level2 in model_xml_level1:
                    tag_level2 = model_xml_level2.tag.lower()

                    if tag_level2 == 'dense-mode':
                        self.dense_mode = self._on_off_to_bool(model_xml_level2.text)

                    elif tag_level2 == 'object':
                        template_object = self._initialize_template_object_xml_parameters()
                        for model_xml_level3 in model_xml_level2:
                            tag_level3 = model_xml_level3.tag.lower()

                            if tag_level3 == 'deformable-object-type':
                                template_object['deformable_object_type'] = model_xml_level3.text.lower()
                            elif tag_level3 == 'attachment-type':
                                template_object['attachment_type'] = model_xml_level3.text.lower()
                            elif tag_level3 == 'kernel-width':
                                template_object['kernel_width'] = float(model_xml_level3.text)
                            elif tag_level3 == 'kernel-type':
                                template_object['kernel_type'] = model_xml_level3.text.lower()
                                if model_xml_level3.text.lower() == 'keops':
                                    self._cuda_is_used = True
                            elif tag_level3 == 'noise-std':
                                template_object['noise_std'] = float(model_xml_level3.text)
                            elif tag_level3 == 'filename':
                                template_object['filename'] = _path(model_xml_level3.text)
                            elif tag_level3 == 'noise-variance-prior-scale-std':
                                template_object['noise_variance_prior_scale_std'] = float(model_xml_level3.text)
                            elif tag_level3 == 'noise-variance-prior-normalized-dof':
                                template_object['noise_variance_prior_normalized_dof'] = float(model_xml_level3.text)
                            else:
                                msg = 'Unknown entry while parsing the template > ' + model_xml_level2.attrib['id'] + \
                                      ' object section of the model xml: ' + model_xml_level3.tag
                                warnings.warn(msg)

                        # Register the object once, after all of its entries have been parsed.
                        self.template_specifications[model_xml_level2.attrib['id']] = template_object

                    else:
                        msg = 'Unknown entry while parsing the template section of the model xml: ' \
                              + model_xml_level2.tag
                        warnings.warn(msg)

            elif tag_level1 == 'deformation-parameters':
                for model_xml_level2 in model_xml_level1:
                    tag_level2 = model_xml_level2.tag.lower()

                    if tag_level2 == 'kernel-width':
                        self.deformation_kernel_width = float(model_xml_level2.text)
                    elif tag_level2 == 'exponential-type':
                        self.exponential_type = model_xml_level2.text
                    elif tag_level2 == 'kernel-type':
                        self.deformation_kernel_type = model_xml_level2.text.lower()
                        if model_xml_level2.text.lower() == 'keops':
                            self._cuda_is_used = True
                    elif tag_level2 == 'number-of-timepoints':
                        self.number_of_time_points = int(model_xml_level2.text)
                    elif tag_level2 == 'number-of-interpolation-points':
                        self.number_of_interpolation_points = int(model_xml_level2.text)
                    elif tag_level2 == 'concentration-of-timepoints':
                        self.concentration_of_time_points = int(model_xml_level2.text)
                    elif tag_level2 == 'number-of-sources':
                        self.number_of_sources = int(model_xml_level2.text)
                    elif tag_level2 == 't0':
                        self.t0 = float(model_xml_level2.text)
                    elif tag_level2 == 'tmin':
                        self.tmin = float(model_xml_level2.text)
                    elif tag_level2 == 'tmax':
                        self.tmax = float(model_xml_level2.text)
                    elif tag_level2 == 'p0':
                        self.p0 = model_xml_level2.text
                    elif tag_level2 == 'v0':
                        self.v0 = model_xml_level2.text
                    elif tag_level2 == 'metric-parameters-file':  # For metric learning.
                        self.metric_parameters_file = model_xml_level2.text
                    elif tag_level2 == 'interpolation-points-file':  # For metric learning.
                        self.interpolation_points_file = model_xml_level2.text
                    elif tag_level2 == 'covariance-momenta-prior-normalized-dof':
                        self.covariance_momenta_prior_normalized_dof = float(model_xml_level2.text)
                    else:
                        msg = 'Unknown entry while parsing the deformation-parameters section of the model xml: ' \
                              + model_xml_level2.tag
                        warnings.warn(msg)

            elif tag_level1 == 'use-exp-parallelization':
                self.use_exp_parallelization = self._on_off_to_bool(model_xml_level1.text)

            else:
                msg = 'Unknown entry while parsing root of the model xml: ' + model_xml_level1.tag
                warnings.warn(msg)
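
<!-- A hand-written sketch of the kind of model.xml this parser consumes.
     The tag set is taken from the branches above; all values are
     illustrative only. -->
<model>
    <model-type>DeterministicAtlas</model-type>
    <dimension>3</dimension>
    <template>
        <object id="example-object">
            <deformable-object-type>SurfaceMesh</deformable-object-type>
            <attachment-type>Varifold</attachment-type>
            <kernel-type>keops</kernel-type>
            <kernel-width>10.0</kernel-width>
            <noise-std>1.0</noise-std>
            <filename>template.vtk</filename>
        </object>
    </template>
    <deformation-parameters>
        <kernel-type>keops</kernel-type>
        <kernel-width>15.0</kernel-width>
        <number-of-timepoints>10</number-of-timepoints>
    </deformation-parameters>
</model>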
Example #9
    def _further_initialization(self):

        if self.dense_mode:
            Settings().dense_mode = self.dense_mode
            print(
                '>> Dense mode activated. No distinction will be made between template and control points.'
            )
            assert len(self.template_specifications) == 1, \
                'Only a single object can be considered when using the dense mode.'
            if not self.freeze_control_points:
                self.freeze_control_points = True
                msg = 'With active dense mode, the freeze_template (currently %s) and freeze_control_points ' \
                      '(currently %s) flags are redundant. Defaulting to freeze_control_points = True.' \
                      % (str(self.freeze_template), str(self.freeze_control_points))
                warnings.warn(msg)
            if self.initial_control_points is not None:
                self.initial_control_points = None
                msg = 'With active dense mode, specifying initial_control_points is useless. Ignoring this xml entry.'
                warnings.warn(msg)

        if self.initial_cp_spacing < 0 and self.initial_control_points is None and not self.dense_mode:
            print(
                '>> No initial CP spacing given: using diffeo kernel width of '
                + str(self.deformation_kernel_width))
            self.initial_cp_spacing = self.deformation_kernel_width

        # Setting tensor types according to CUDA availability and user choices.
        if self._cuda_is_used:
            if not torch.cuda.is_available():
                msg = 'CUDA seems to be unavailable. All computations will be carried out on CPU.'
                warnings.warn(msg)
            else:
                print(
                    ">> CUDA is used at least in one operation, all operations will be done with FLOAT precision."
                )
                if self.use_cuda:
                    print(">> All tensors will be CUDA tensors.")
                    Settings().tensor_scalar_type = torch.cuda.FloatTensor
                    Settings().tensor_integer_type = torch.cuda.LongTensor
                else:
                    print(">> Setting tensor type to float.")
                    Settings().tensor_scalar_type = torch.FloatTensor

        # Setting the dimension.
        Settings().dimension = self.dimension

        # For longitudinal models, initialize t0 and the time-shift variance if needed.
        if (self.model_type == 'regression' or self.model_type == 'LongitudinalAtlas'.lower()
            or self.model_type == 'LongitudinalRegistration'.lower()) \
                and (self.t0 is None or self.initial_time_shift_variance is None):
            total_number_of_visits = 0
            mean_visit_age = 0.0
            var_visit_age = 0.0
            for i in range(len(self.visit_ages)):
                for j in range(len(self.visit_ages[i])):
                    total_number_of_visits += 1
                    mean_visit_age += self.visit_ages[i][j]
                    var_visit_age += self.visit_ages[i][j]**2

            if total_number_of_visits > 0:
                mean_visit_age /= float(total_number_of_visits)
                var_visit_age = (
                    var_visit_age / float(total_number_of_visits) -
                    mean_visit_age**2)

                if self.t0 is None:
                    print('>> Initial t0 set to the mean visit age: %.2f' %
                          mean_visit_age)
                    self.t0 = mean_visit_age
                else:
                    print(
                        '>> Initial t0 set by the user to %.2f ; note that the mean visit age is %.2f'
                        % (self.t0, mean_visit_age))

                if self.model_type != 'regression':
                    if self.initial_time_shift_variance is None:
                        print(
                            '>> Initial time-shift std set to the empirical std of the visit ages: %.2f'
                            % math.sqrt(var_visit_age))
                        self.initial_time_shift_variance = var_visit_age
                    else:
                        print((
                            '>> Initial time-shift std set by the user to %.2f ; note that the empirical std of '
                            'the visit ages is %.2f') %
                              (math.sqrt(self.initial_time_shift_variance),
                               math.sqrt(var_visit_age)))

        # Setting the number of threads in general settings
        Settings().number_of_threads = self.number_of_threads
        if self.number_of_threads > 1:
            print(
                ">> I will use", self.number_of_threads,
                "threads, and I set OMP_NUM_THREADS and torch_num_threads to 1."
            )
            os.environ['OMP_NUM_THREADS'] = "1"
            torch.set_num_threads(1)

            # The "spawn" start method is only relevant in the multiprocessing case.
            try:
                set_start_method("spawn")
            except RuntimeError as error:
                print('>> Warning: ' + str(error) +
                      ' [ in xml_parameters ]. Ignoring.')
        else:
            print('>> Setting OMP_NUM_THREADS and torch_num_threads to 4.')
            os.environ['OMP_NUM_THREADS'] = "4"
            torch.set_num_threads(4)

        self._initialize_state_file()

        # Freeze the fixed effects in case of a registration.
        if self.model_type == 'Registration'.lower():
            self.freeze_template = True
            self.freeze_control_points = True

        elif self.model_type == 'LongitudinalRegistration'.lower():
            self.freeze_template = True
            self.freeze_control_points = True
            self.freeze_momenta = True
            self.freeze_modulation_matrix = True
            self.freeze_reference_time = True
            self.freeze_time_shift_variance = True
            self.freeze_log_acceleration_variance = True
            self.freeze_noise_variance = True

        # Initialize the number of sources if needed.
        if self.model_type == 'LongitudinalAtlas'.lower() \
                and self.initial_modulation_matrix is None and self.number_of_sources is None:
            self.number_of_sources = 4
            print(
                '>> No initial modulation matrix given, neither a number of sources. '
                'The latter will be ARBITRARILY defaulted to 4.')

        if self.dimension <= 1:
            print(
                ">> Setting the number of sources to 0 because the dimension is at most 1."
            )
            self.number_of_sources = 0

        # Initialize the initial_log_acceleration_variance if needed.
        if (self.model_type == 'LongitudinalAtlas'.lower() or self.model_type == 'LongitudinalRegistration'.lower()) \
                and self.initial_log_acceleration_variance is None:
            print(
                '>> The initial log-acceleration std fixed effect is ARBITRARILY set to 0.5'
            )
            log_acceleration_std = 0.5
            self.initial_log_acceleration_variance = (log_acceleration_std**2)

        # Image grid downsampling factor.
        if self.downsampling_factor != 1:
            image_object_specs = [
                (key, value)
                for key, value in self.template_specifications.items()
                if value['deformable_object_type'].lower() == 'image'
            ]
            if len(image_object_specs) > 1:
                raise RuntimeError('Only a single image object can be used.')
            elif len(image_object_specs) == 1:
                print('>> Setting the image grid downsampling factor to: %d.' %
                      self.downsampling_factor)
                self.template_specifications[image_object_specs[0][0]][
                    'downsampling_factor'] = self.downsampling_factor
            else:
                msg = 'The "downsampling_factor" parameter is useful only for image data, ' \
                      'but none is considered here. Ignoring.'
                warnings.warn(msg)
Example #10
    def update(self):
        """
        Update the geodesic, and compute the parallel transport of each column of the modulation matrix along
        this geodesic, ignoring the tangential components.
        """
        # Update the geodesic.
        self.geodesic.update()

        # Convenient attributes for later use.
        self.times = self.geodesic._get_times()
        self.template_points_t = self.geodesic._get_template_points_trajectory()
        self.control_points_t = self.geodesic._get_control_points_trajectory()

        if self.transport_is_modified:
            # Initializes the projected_modulation_matrix_t attribute size.
            self.projected_modulation_matrix_t = \
                [Variable(torch.zeros(self.modulation_matrix_t0.size()).type(Settings().tensor_scalar_type),
                          requires_grad=False) for _ in range(len(self.control_points_t))]

            # Transport each column, ignoring the tangential components.
            for s in range(self.number_of_sources):
                space_shift_t0 = self.modulation_matrix_t0[:, s].contiguous().view(
                    self.geodesic.momenta_t0.size())
                space_shift_t = self.geodesic.parallel_transport(
                    space_shift_t0, is_orthogonal=True)

                # Set the result correctly in the projected_modulation_matrix_t attribute.
                for t, space_shift in enumerate(space_shift_t):
                    self.projected_modulation_matrix_t[t][:, s] = space_shift.view(-1)

            self.transport_is_modified = False
            self.backward_extension = 0
            self.forward_extension = 0

        elif self.backward_extension > 0 or self.forward_extension > 0:

            # Initializes the extended projected_modulation_matrix_t variable.
            projected_modulation_matrix_t_extended = \
                [Variable(torch.zeros(self.modulation_matrix_t0.size()).type(Settings().tensor_scalar_type),
                          requires_grad=False) for _ in range(len(self.control_points_t))]

            # Transport each column, ignoring the tangential components.
            for s in range(self.number_of_sources):
                space_shift_t = [
                    elt[:, s].contiguous().view(self.geodesic.momenta_t0.size())
                    for elt in self.projected_modulation_matrix_t
                ]
                space_shift_t = self.geodesic.extend_parallel_transport(
                    space_shift_t,
                    self.backward_extension,
                    self.forward_extension,
                    is_orthogonal=True)

                for t, space_shift in enumerate(space_shift_t):
                    projected_modulation_matrix_t_extended[t][:, s] = space_shift.view(-1)

            self.projected_modulation_matrix_t = projected_modulation_matrix_t_extended
            self.backward_extension = 0
            self.forward_extension = 0

        assert len(self.template_points_t[list(self.template_points_t.keys())[0]]) == len(self.control_points_t) \
                == len(self.times) == len(self.projected_modulation_matrix_t), \
            "That's weird: len(self.template_points_t[list(self.template_points_t.keys())[0]]) = %d, " \
            "len(self.control_points_t) = %d, len(self.times) = %d,  len(self.projected_modulation_matrix_t) = %d" % \
            (len(self.template_points_t[list(self.template_points_t.keys())[0]]), len(self.control_points_t),
             len(self.times), len(self.projected_modulation_matrix_t))
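
# What "ignoring the tangential components" (is_orthogonal=True) plausibly
# amounts to, written out as a standalone sketch (an assumption made explicit
# here, not code from the original source): before transport, remove from each
# column w its component along the geodesic momentum m under the kernel metric
# K, keeping only the orthogonal part.
import numpy as np

def orthogonal_part(w, m, K):
    # w, m: flattened vectors of equal size; K: kernel matrix defining the metric.
    return w - (w @ (K @ m)) / (m @ (K @ m)) * m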
Example #11
def create_regular_grid_of_points(box, spacing):
    """
    Creates a regular grid of points in dimension 1 to 4, as a numpy array of size nb_of_points x dimension.
    box: (dimension, 2)
    """

    dimension = Settings().dimension
    if not 1 <= dimension <= 4:
        raise RuntimeError('Invalid ambient space dimension.')

    axis = []
    for d in range(dimension):
        axis_min, axis_max = box[d, 0], box[d, 1]
        length = axis_max - axis_min
        assert length > 0

        # Center the grid within the box.
        offset = 0.5 * (length - spacing * math.floor(length / spacing))
        axis.append(np.arange(axis_min + offset, axis_max + 1e-10, spacing))

    # Default ('xy') meshgrid indexing, as in the original per-dimension code.
    grids = np.meshgrid(*axis)
    assert all(grid.shape == grids[0].shape for grid in grids)

    control_points = np.zeros((grids[0].size, dimension))
    for d, grid in enumerate(grids):
        control_points[:, d] = grid.flatten()

    return control_points
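
# Example use (a sketch; assumes Settings().dimension has been set to 2):
# a unit square with spacing 0.5 yields a centered 3 x 3 grid of points.
box = np.array([[0.0, 1.0],
                [0.0, 1.0]])
points = create_regular_grid_of_points(box, spacing=0.5)
print(points.shape)  # (9, 2)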
Example #12
def main():
    import logging
    logger = logging.getLogger(__name__)
    logger_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

    # parse arguments
    parser = argparse.ArgumentParser(description='Deformetrica')
    parser.add_argument('model', type=str, help='model xml file')
    parser.add_argument('dataset', type=str, help='data-set xml file')
    parser.add_argument('optimization',
                        type=str,
                        help='optimization parameters xml file')

    # optional arguments
    parser.add_argument('-o', '--output', type=str, help='output folder')
    # logging levels: https://docs.python.org/2/library/logging.html#logging-levels
    parser.add_argument(
        '--verbosity',
        '-v',
        type=str,
        default='WARNING',
        choices=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
        help='set output verbosity')

    args = parser.parse_args()

    # set logging level
    log_level = logging.getLevelName(args.verbosity)
    if not isinstance(log_level, int):
        logger.warning('Logging level was not recognized. Using INFO.')
        log_level = logging.INFO
    logging.basicConfig(level=log_level, format=logger_format)

    logger.debug('Using verbosity level: ' + args.verbosity)

    # Basic info printing.
    logger.info(info())

    # Read xml files, set general settings, and call the adapted function.
    try:
        if args.output is None:
            logger.info('Creating the output directory: ' +
                        Settings().output_dir)
            os.makedirs(Settings().output_dir)
        else:
            logger.info('Setting output directory to: ' + args.output)
            Settings().set_output_dir(args.output)
    except FileExistsError:
        pass

    logger.info('[ read_all_xmls function ]')
    xml_parameters = XmlParameters()
    xml_parameters.read_all_xmls(args.model, args.dataset, args.optimization)

    if xml_parameters.model_type == 'DeterministicAtlas'.lower() \
            or xml_parameters.model_type == 'Registration'.lower():
        estimate_deterministic_atlas(xml_parameters)

    elif xml_parameters.model_type == 'BayesianAtlas'.lower():
        estimate_bayesian_atlas(xml_parameters)

    elif xml_parameters.model_type == 'Regression'.lower():
        estimate_geodesic_regression(xml_parameters)

    elif xml_parameters.model_type == 'LongitudinalAtlas'.lower():
        estimate_longitudinal_atlas(xml_parameters)

    elif xml_parameters.model_type == 'LongitudinalRegistration'.lower():
        estimate_longitudinal_registration(xml_parameters)

    elif xml_parameters.model_type == 'Shooting'.lower():
        run_shooting(xml_parameters)

    elif xml_parameters.model_type == 'ParallelTransport'.lower():
        compute_parallel_transport(xml_parameters)

    elif xml_parameters.model_type == 'LongitudinalMetricLearning'.lower():
        estimate_longitudinal_metric_model(xml_parameters)

    elif xml_parameters.model_type == 'LongitudinalMetricRegistration'.lower():
        estimate_longitudinal_metric_registration(xml_parameters)

    else:
        raise RuntimeError(
            'Unrecognized model-type: "' + xml_parameters.model_type +
            '". Check the corresponding field in the model.xml input file.')
Example #13
 def set_connectivity(self, connectivity):
     self.connectivity = torch.from_numpy(connectivity).type(
         Settings().tensor_integer_type)
     self.is_modified = True
Example #14
 def test_write_3D_array(self):
     momenta = read_3D_array(os.path.join(Settings().unit_tests_data_dir, "Momenta.txt"))
     write_3D_array(momenta, self.test_output_file_path)
     read = read_3D_array(self.test_output_file_path)
     self.assertTrue(np.allclose(momenta, read))
Example #15
 def test_read_3D_array(self):
     momenta = read_3D_array(os.path.join(Settings().unit_tests_data_dir, "Momenta.txt"))
     self.assertEqual(momenta.shape, (4, 72, 3))
     self.assertTrue(np.allclose(momenta[0, 0], np.array([-0.0313538, -0.00373486, -0.0256917])))
     self.assertTrue(np.allclose(momenta[0, -1], np.array([-0.518624, 1.47211, 0.880905])))
     self.assertTrue(np.allclose(momenta[-1, -1], np.array([2.81286, -0.353167, -2.16408])))
Example #16
 def inverse_metric(self, q):
     # Constant identity inverse metric: it does not depend on the position q.
     return Variable(
         torch.eye(self.dimension).type(Settings().tensor_scalar_type))