def initialize_spatiotemporal_reference_frame(model, xml_parameters, dataset, observation_type='image'):
    """
    Configure the exponential factory, the metric parameters and the reference frame of `model`.

    :param model: the model to configure; its `spatiotemporal_reference_frame`, metric flags,
        `no_parallel_transport` and `number_of_sources` attributes are written here.
    :param xml_parameters: the parsed xml options driving the configuration.
    :param dataset: only forwarded to `_initialize_parametric_exponential`.
    :param observation_type: accepted for interface compatibility; not read by this function.
    """
    assert xml_parameters.dimension is not None, "Provide a dimension for the longitudinal metric learning atlas."

    # Choose the manifold type of the exponential, when the user asked for one.
    exponential_factory = ExponentialFactory()
    if xml_parameters.exponential_type is None:
        msg = "Defaulting exponential type to parametric"
        warnings.warn(msg)
    else:
        print("Initializing exponential type to", xml_parameters.exponential_type)
        exponential_factory.set_manifold_type(xml_parameters.exponential_type)

    # Metric parameters stored on disk take precedence as the starting point.
    metric_parameters = None
    parameters_file = xml_parameters.metric_parameters_file
    if parameters_file is not None:
        print("Loading metric parameters from file", parameters_file)
        metric_parameters = np.loadtxt(parameters_file)

    # Manifold-specific initialization.
    if exponential_factory.manifold_type == 'parametric':
        metric_parameters = _initialize_parametric_exponential(
            model, xml_parameters, dataset, exponential_factory, metric_parameters)

    if exponential_factory.manifold_type == 'deep':
        exponential_factory.set_parameters(
            {'latent_space_dimension': xml_parameters.latent_space_dimension})
    elif exponential_factory.manifold_type == 'logistic':
        # No initial parameter to set ! Just freeze the model parameters (or even delete the key ?)
        model.is_frozen['metric_parameters'] = True

    # Build the reference frame around the configured factory.
    model.spatiotemporal_reference_frame = GenericSpatiotemporalReferenceFrame(exponential_factory)
    model.spatiotemporal_reference_frame.set_concentration_of_time_points(
        xml_parameters.concentration_of_time_points)
    model.spatiotemporal_reference_frame.set_number_of_time_points(
        xml_parameters.number_of_time_points)

    model.parametric_metric = (xml_parameters.exponential_type in ['parametric'])

    if xml_parameters.exponential_type == 'deep':
        model.deep_metric_learning = True
        model.latent_space_dimension = xml_parameters.latent_space_dimension
        model.initialize_deep_metric_learning()
        model.set_metric_parameters(metric_parameters)

    if xml_parameters.exponential_type == 'parametric':
        model.is_frozen['metric_parameters'] = xml_parameters.freeze_metric_parameters
        model.set_metric_parameters(metric_parameters)

    # Parallel transport is disabled in dimension 1 and when no source is requested.
    if Settings().dimension == 1:
        print("I am setting the no_parallel_transport flag to True because the dimension is 1")
        skip_transport, source_count = True, 0
    elif xml_parameters.number_of_sources == 0 or xml_parameters.number_of_sources is None:
        print("I am setting the no_parallel_transport flag to True because the number of sources is 0.")
        skip_transport, source_count = True, 0
    else:
        print("I am setting the no_parallel_transport flag to False.")
        skip_transport, source_count = False, xml_parameters.number_of_sources

    model.no_parallel_transport = skip_transport
    model.spatiotemporal_reference_frame.no_parallel_transport = skip_transport
    model.number_of_sources = source_count
def update(self):
    """
    Runs the gradient ascent algorithm and updates the statistical model.

    Starting either from the saved state file (when Settings().load_state is set) or from the
    model's current parameters, the method iterates until `max_iterations` is reached, the line
    search fails, or the relative improvement of the log-likelihood drops below
    `convergence_tolerance`. Each iteration performs a backtracking line search with one step
    size per gradient key.
    """

    # Initialisation ---------------------------------------------------------------------------------------------------
    # First case: we use the initialization stored in the state file.
    if Settings().load_state:
        self.current_parameters, self.current_iteration = self._load_state_file()
        self._set_parameters(self.current_parameters)  # Propagate the parameter values.
        # The iteration is passed as a lazy %-argument: a message without placeholder makes logging fail.
        logger.info("State file loaded, it was at iteration %s", self.current_iteration)

    # Second case: we use the native initialization of the model.
    else:
        self.current_parameters = self._get_parameters()
        self.current_iteration = 0

    # To check the model gradient, call self._check_model_gradient() here.
    # WARNING: don't forget to comment the update_fixed_effects method of the model in that case !

    self.current_attachment, self.current_regularity, gradient = self._evaluate_model_fit(
        self.current_parameters, with_grad=True)
    self.current_log_likelihood = self.current_attachment + self.current_regularity
    self.print()

    initial_log_likelihood = self.current_log_likelihood
    last_log_likelihood = initial_log_likelihood

    nb_params = len(gradient)
    self.step = self._initialize_step_size(gradient)

    # Main loop --------------------------------------------------------------------------------------------------------
    while self.current_iteration < self.max_iterations:
        self.current_iteration += 1

        # Line search --------------------------------------------------------------------------------------------------
        found_min = False
        for _ in range(self.max_line_search_iterations):

            # Print step size ------------------------------------------------------------------------------------------
            if not (self.current_iteration % self.print_every_n_iters):
                logger.debug('Step size and gradient squared norm: ')
                for key in gradient.keys():
                    logger.debug('\t\t%.3E and %.3E \t[ %s ]' % (Decimal(str(self.step[key])),
                                                                 Decimal(str(np.sum(gradient[key] ** 2))), key))

            # Try a simple gradient ascent step ------------------------------------------------------------------------
            new_parameters = self._gradient_ascent_step(self.current_parameters, gradient, self.step)
            new_attachment, new_regularity = self._evaluate_model_fit(new_parameters)

            q = new_attachment + new_regularity - last_log_likelihood
            if q > 0:
                found_min = True
                self.step = {key: value * self.line_search_expand for key, value in self.step.items()}
                break

            # Adapting the step sizes ----------------------------------------------------------------------------------
            # All steps are shrunk; then, for each key in turn, a proposal keeps that key's step un-shrunk.
            self.step = {key: value * self.line_search_shrink for key, value in self.step.items()}
            if nb_params > 1:
                new_parameters_prop = {}
                new_attachment_prop = {}
                new_regularity_prop = {}
                q_prop = {}

                for key in self.step.keys():
                    local_step = self.step.copy()
                    local_step[key] /= self.line_search_shrink

                    new_parameters_prop[key] = self._gradient_ascent_step(
                        self.current_parameters, gradient, local_step)
                    new_attachment_prop[key], new_regularity_prop[key] = self._evaluate_model_fit(
                        new_parameters_prop[key])
                    q_prop[key] = new_attachment_prop[key] + new_regularity_prop[key] - last_log_likelihood

                # Keep the best proposal if it improves the log-likelihood.
                key_max = max(q_prop.keys(), key=(lambda key: q_prop[key]))
                if q_prop[key_max] > 0:
                    new_attachment = new_attachment_prop[key_max]
                    new_regularity = new_regularity_prop[key_max]
                    new_parameters = new_parameters_prop[key_max]
                    self.step[key_max] /= self.line_search_shrink
                    found_min = True
                    break

        # End of line search -------------------------------------------------------------------------------------------
        if not found_min:
            # Restore the last accepted parameters, since the model was evaluated at rejected candidates.
            self._set_parameters(self.current_parameters)
            logger.info('Number of line search loops exceeded. Stopping.')
            break

        self.current_attachment = new_attachment
        self.current_regularity = new_regularity
        self.current_log_likelihood = new_attachment + new_regularity
        self.current_parameters = new_parameters
        self._set_parameters(self.current_parameters)

        # Test the stopping criterion ----------------------------------------------------------------------------------
        current_log_likelihood = self.current_log_likelihood
        delta_f_current = last_log_likelihood - current_log_likelihood
        delta_f_initial = initial_log_likelihood - current_log_likelihood

        if math.fabs(delta_f_current) < self.convergence_tolerance * math.fabs(delta_f_initial):
            logger.info('Tolerance threshold met. Stopping the optimization process.')
            break

        # Printing and writing -----------------------------------------------------------------------------------------
        if not self.current_iteration % self.print_every_n_iters:
            self.print()
        if not self.current_iteration % self.save_every_n_iters:
            self.write()

        # Prepare next iteration ---------------------------------------------------------------------------------------
        last_log_likelihood = current_log_likelihood
        # The counter reaches max_iterations on the very last pass: only then is a new gradient useless.
        if self.current_iteration != self.max_iterations:
            gradient = self._evaluate_model_fit(self.current_parameters, with_grad=True)[2]

        # Save the state.
        if not self.current_iteration % self.save_every_n_iters:
            self._dump_state_file()
def _load_state_file(self):
    """
    Loads Settings().state_file and returns what's necessary to restart the scipy optimization.

    :return: the tuple (parameters, current_iteration, parameters_shape, parameters_order) read
        from the pickled dictionary.
    """
    # NOTE: pickle can execute arbitrary code while loading; only trusted state files must be used.
    # The context manager closes the file handle even if unpickling fails.
    with open(Settings().state_file, 'rb') as state_file:
        d = pickle.load(state_file)
    return d['parameters'], d['current_iteration'], d['parameters_shape'], d['parameters_order']
def _load_state_file(self):
    """
    Loads Settings().state_file and returns the saved optimization state.

    :return: the tuple (current_parameters, current_iteration) read from the pickled dictionary.
    """
    # NOTE: pickle can execute arbitrary code while loading; only trusted state files must be used.
    # The context manager closes the file handle even if unpickling fails.
    with open(Settings().state_file, 'rb') as state_file:
        d = pickle.load(state_file)
    return d['current_parameters'], d['current_iteration']
def _dump_state_file(self):
    """
    Pickles the current parameters and iteration counter into Settings().state_file.
    """
    d = {
        'current_parameters': self.current_parameters,
        'current_iteration': self.current_iteration
    }
    # The context manager guarantees that the state is flushed and the handle closed.
    with open(Settings().state_file, 'wb') as state_file:
        pickle.dump(d, state_file)
def extend(self, number_of_additional_time_points):
    """
    Prolong the stored trajectories by `number_of_additional_time_points` extra time points.

    The control points and momenta are shot further with the same time-step, the momenta and the
    squared norm are rescaled by the length ratio, and the template points are then flowed along
    the added steps. An exponential holding a single time point is simply enlarged and recomputed
    through `update()`.
    """
    extra = number_of_additional_time_points

    # Special case of the exponential reduced to a single point.
    if self.number_of_time_points == 1:
        self.number_of_time_points += extra
        self.update()
        return

    # Extended shoot, with the same time-step as the existing trajectory.
    previous_count = self.number_of_time_points
    dt = 1.0 / float(previous_count - 1)
    for _ in range(extra):
        last_cp = self.control_points_t[-1]
        last_mom = self.momenta_t[-1]
        if self.use_rk2:
            next_cp, next_mom = self._rk2_step(last_cp, last_mom, dt, return_mom=True)
        else:
            next_cp, next_mom = self._euler_step(last_cp, last_mom, dt)
        self.control_points_t.append(next_cp)
        self.momenta_t.append(next_mom)

    # Scaling of the new length.
    length_ratio = float(previous_count + extra - 1) / float(previous_count - 1)
    self.number_of_time_points += extra
    self.initial_momenta = self.initial_momenta * length_ratio
    self.momenta_t = [momenta * length_ratio for momenta in self.momenta_t]
    self.norm_squared = self.norm_squared * length_ratio ** 2

    # Extended flow.
    # Special case of the dense mode: the landmark trajectory is the control points trajectory.
    if Settings().dense_mode:
        assert 'image_points' not in self.initial_template_points.keys(), 'Dense mode not allowed with image data.'
        self.template_points_t['landmark_points'] = self.control_points_t
        return

    # Standard case.
    # Flow landmark points.
    if 'landmark_points' in self.initial_template_points:
        landmarks_t = self.template_points_t['landmark_points']
        for _ in range(extra):
            i = len(landmarks_t) - 1
            velocity = self.kernel.convolve(landmarks_t[i], self.control_points_t[i], self.momenta_t[i])
            landmarks_t.append(landmarks_t[i] + dt * velocity)

            if self.use_rk2:
                # In this case improved euler (= Heun's method) to save one computation of convolve gradient.
                end_velocity = self.kernel.convolve(
                    landmarks_t[i + 1], self.control_points_t[i + 1], self.momenta_t[i + 1])
                landmarks_t[i + 1] = landmarks_t[i] + dt / 2 * (end_velocity + velocity)

    # Flow image points.
    if 'image_points' in self.initial_template_points:
        dimension = Settings().dimension
        image_shape = self.initial_template_points['image_points'].size()
        images_t = self.template_points_t['image_points']

        for _ in range(extra):
            i = len(images_t) - 1
            flat_grid = self.initial_template_points['image_points'].contiguous().view(-1, dimension)
            vf = self.kernel.convolve(flat_grid, self.control_points_t[i], self.momenta_t[i]).view(image_shape)
            dY = self._compute_image_explicit_euler_step_at_order_1(images_t[i], vf)
            images_t.append(images_t[i] - dY)

        if self.use_rk2:
            msg = 'RK2 not implemented to flow image points.'
            warnings.warn(msg)
def __init__(self, kernel_width=None):
    """
    Build the keops kernel and its three `generic_sum` reductions.

    :param kernel_width: forwarded unchanged to the parent constructor.

    The reductions are declared for the ambient dimension currently stored in Settings():
    `gaussian_convolve`, `varifold_convolve` and `gaussian_convolve_gradient_x`.
    """
    self.kernel_type = 'keops'
    super().__init__(kernel_width)

    # The ambient dimension appears in every variable declaration: read it once.
    dim = str(Settings().dimension)
    vx = "Vx(" + dim + ")"
    vy = "Vy(" + dim + ")"

    self.gaussian_convolve = generic_sum(
        "Exp(-G*SqDist(X,Y)) * P",
        "O = " + vx,
        "G = Pm(1)",
        "X = " + vx,
        "Y = " + vy,
        "P = " + vy)

    self.varifold_convolve = generic_sum(
        "Exp(-(WeightedSqDist(G, X, Y))) * Pow((Nx, Ny), 2) * P",
        "O = Vx(1)",
        "G = Pm(1)",
        "X = " + vx,
        "Y = " + vy,
        "Nx = " + vx,
        "Ny = " + vy,
        "P = Vy(1)")

    # The formula previously ended with a dangling ' * ' operator, which is not a complete expression.
    # NOTE(review): any scalar factor of the gradient is presumably applied by the caller -- confirm.
    self.gaussian_convolve_gradient_x = generic_sum(
        "(Px, Py) * Exp(-G*SqDist(X,Y)) * (X-Y)",
        "O = " + vx,
        "G = Pm(1)",
        "X = " + vx,
        "Y = " + vy,
        "Px = " + vx,
        "Py = " + vy)
def _read_model_xml(self, model_xml_path):
    """
    Parse the model xml file and copy every recognized entry onto this object.

    Tags are matched case-insensitively. File entries are resolved relatively to the folder of
    the xml file, "-std" entries are stored squared as variances, and any unknown tag only
    triggers a warning.
    """

    def _resolve(relative_path):
        # Paths in the xml are relative to the directory that contains the xml itself.
        return os.path.normpath(os.path.join(os.path.dirname(model_xml_path), relative_path))

    root = et.parse(model_xml_path).getroot()

    for entry in root:
        tag = entry.tag.lower()

        if tag == 'model-type':
            self.model_type = entry.text.lower()

        elif tag == 'dimension':
            self.dimension = int(entry.text)
            Settings().dimension = self.dimension

        elif tag == 'initial-control-points':
            self.initial_control_points = _resolve(entry.text)

        elif tag == 'initial-momenta':
            self.initial_momenta = _resolve(entry.text)

        elif tag == 'initial-modulation-matrix':
            self.initial_modulation_matrix = _resolve(entry.text)

        elif tag == 'initial-time-shift-std':
            self.initial_time_shift_variance = float(entry.text) ** 2

        elif tag == 'initial-log-acceleration-std':
            self.initial_log_acceleration_variance = float(entry.text) ** 2

        elif tag == 'initial-log-acceleration-mean':
            self.initial_log_acceleration_mean = float(entry.text)

        elif tag == 'initial-onset-ages':
            self.initial_onset_ages = _resolve(entry.text)

        elif tag == 'initial-log-accelerations':
            self.initial_log_accelerations = _resolve(entry.text)

        elif tag == 'initial-sources':
            self.initial_sources = _resolve(entry.text)

        elif tag == 'initial-sources-mean':
            self.initial_sources_mean = entry.text

        elif tag == 'initial-sources-std':
            self.initial_sources_std = entry.text

        elif tag == 'initial-momenta-to-transport':
            self.initial_momenta_to_transport = _resolve(entry.text)

        elif tag == 'initial-control-points-to-transport':
            self.initial_control_points_to_transport = _resolve(entry.text)

        elif tag == 'initial-noise-std':
            self.initial_noise_variance = float(entry.text) ** 2

        elif tag == 'latent-space-dimension':
            self.latent_space_dimension = int(entry.text)

        elif tag == 'template':
            for child in entry:
                child_tag = child.tag.lower()

                if child_tag == 'dense-mode':
                    self.dense_mode = self._on_off_to_bool(child.text)

                elif child_tag == 'object':
                    template_object = self._initialize_template_object_xml_parameters()

                    for prop in child:
                        prop_tag = prop.tag.lower()

                        if prop_tag == 'deformable-object-type':
                            template_object['deformable_object_type'] = prop.text.lower()
                        elif prop_tag == 'attachment-type':
                            template_object['attachment_type'] = prop.text.lower()
                        elif prop_tag == 'kernel-width':
                            template_object['kernel_width'] = float(prop.text)
                        elif prop_tag == 'kernel-type':
                            template_object['kernel_type'] = prop.text.lower()
                            # A keops kernel anywhere switches the CUDA flag on.
                            if prop.text.lower() == 'keops'.lower():
                                self._cuda_is_used = True
                        elif prop_tag == 'noise-std':
                            template_object['noise_std'] = float(prop.text)
                        elif prop_tag == 'filename':
                            template_object['filename'] = _resolve(prop.text)
                        elif prop_tag == 'noise-variance-prior-scale-std':
                            template_object['noise_variance_prior_scale_std'] = float(prop.text)
                        elif prop_tag == 'noise-variance-prior-normalized-dof':
                            template_object['noise_variance_prior_normalized_dof'] = float(prop.text)
                        else:
                            msg = 'Unknown entry while parsing the template > ' + child.attrib['id'] + \
                                  ' object section of the model xml: ' + prop.tag
                            warnings.warn(msg)

                    self.template_specifications[child.attrib['id']] = template_object

                else:
                    msg = 'Unknown entry while parsing the template section of the model xml: ' \
                          + child.tag
                    warnings.warn(msg)

        elif tag == 'deformation-parameters':
            for child in entry:
                child_tag = child.tag.lower()

                if child_tag == 'kernel-width':
                    self.deformation_kernel_width = float(child.text)
                elif child_tag == 'exponential-type':
                    self.exponential_type = child.text
                elif child_tag == 'kernel-type':
                    self.deformation_kernel_type = child.text.lower()
                    if child.text.lower() == 'keops'.lower():
                        self._cuda_is_used = True
                elif child_tag == 'number-of-timepoints':
                    self.number_of_time_points = int(child.text)
                elif child_tag == 'number-of-interpolation-points':
                    self.number_of_interpolation_points = int(child.text)
                elif child_tag == 'concentration-of-timepoints':
                    self.concentration_of_time_points = int(child.text)
                elif child_tag == 'number-of-sources':
                    self.number_of_sources = int(child.text)
                elif child_tag == 't0':
                    self.t0 = float(child.text)
                elif child_tag == 'tmin':
                    self.tmin = float(child.text)
                elif child_tag == 'tmax':
                    self.tmax = float(child.text)
                elif child_tag == 'p0':
                    self.p0 = child.text
                elif child_tag == 'v0':
                    self.v0 = child.text
                elif child_tag == 'metric-parameters-file':  # for metric learning
                    self.metric_parameters_file = child.text
                elif child_tag == 'interpolation-points-file':  # for metric learning
                    self.interpolation_points_file = child.text
                elif child_tag == 'covariance-momenta-prior-normalized-dof':
                    self.covariance_momenta_prior_normalized_dof = float(child.text)
                else:
                    msg = 'Unknown entry while parsing the deformation-parameters section of the model xml: ' \
                          + child.tag
                    warnings.warn(msg)

        elif tag == 'use-exp-parallelization':
            self.use_exp_parallelization = self._on_off_to_bool(entry.text)

        else:
            msg = 'Unknown entry while parsing root of the model xml: ' + entry.tag
            warnings.warn(msg)
def _further_initialization(self):
    """
    Finalize the parameters once they have been read: resolve defaults and push global settings.

    In order: dense mode checks, default control point spacing, tensor types (CUDA or CPU),
    ambient dimension, default t0 / time-shift variance from the visit ages, threading setup,
    state file initialization, frozen effects for registration models, default number of
    sources, default log-acceleration variance and image downsampling factor.

    :raises RuntimeError: if a downsampling factor is requested with more than one image object.
    """

    # Dense mode -------------------------------------------------------------------------------------------------------
    if self.dense_mode:
        Settings().dense_mode = self.dense_mode
        print('>> Dense mode activated. No distinction will be made between template and control points.')
        assert len(self.template_specifications) == 1, \
            'Only a single object can be considered when using the dense mode.'
        if not self.freeze_control_points:
            # The message is built before overriding the flag, so that it reports the user's value.
            msg = 'With active dense mode, the freeze_template (currently %s) and freeze_control_points ' \
                  '(currently %s) flags are redundant. Defaulting to freeze_control_points = True.' \
                  % (str(self.freeze_template), str(self.freeze_control_points))
            self.freeze_control_points = True
            warnings.warn(msg)
        if self.initial_control_points is not None:
            self.initial_control_points = None
            msg = 'With active dense mode, specifying initial_control_points is useless. Ignoring this xml entry.'
            warnings.warn(msg)

    # Default control point spacing --------------------------------------------------------------------------------------
    if self.initial_cp_spacing < 0 and self.initial_control_points is None and not self.dense_mode:
        print('>> No initial CP spacing given: using diffeo kernel width of '
              + str(self.deformation_kernel_width))
        self.initial_cp_spacing = self.deformation_kernel_width

    # Setting tensor types according to CUDA availability and user choices.
    if self._cuda_is_used:
        if not torch.cuda.is_available():
            msg = 'CUDA seems to be unavailable. All computations will be carried out on CPU.'
            warnings.warn(msg)
        else:
            print(">> CUDA is used at least in one operation, all operations will be done with FLOAT precision.")
            if self.use_cuda:
                print(">> All tensors will be CUDA tensors.")
                Settings().tensor_scalar_type = torch.cuda.FloatTensor
                Settings().tensor_integer_type = torch.cuda.LongTensor
            else:
                print(">> Setting tensor type to float.")
                Settings().tensor_scalar_type = torch.FloatTensor

    # Setting the dimension.
    Settings().dimension = self.dimension

    # If longitudinal model and t0 is not initialized, initializes it.
    is_time_model = self.model_type in ('regression',
                                        'LongitudinalAtlas'.lower(),
                                        'LongitudinalRegistration'.lower())
    if is_time_model and (self.t0 is None or self.initial_time_shift_variance is None):
        total_number_of_visits = 0
        mean_visit_age = 0.0
        var_visit_age = 0.0
        for subject_ages in self.visit_ages:
            for age in subject_ages:
                total_number_of_visits += 1
                mean_visit_age += age
                var_visit_age += age ** 2

        if total_number_of_visits > 0:
            mean_visit_age /= float(total_number_of_visits)
            # E[x^2] - E[x]^2 can come out slightly negative through rounding (e.g. identical ages):
            # clamp it so that the square roots below cannot fail.
            var_visit_age = max(var_visit_age / float(total_number_of_visits) - mean_visit_age ** 2, 0.0)

            if self.t0 is None:
                print('>> Initial t0 set to the mean visit age: %.2f' % mean_visit_age)
                self.t0 = mean_visit_age
            else:
                print('>> Initial t0 set by the user to %.2f ; note that the mean visit age is %.2f'
                      % (self.t0, mean_visit_age))

            if not self.model_type == 'regression':
                if self.initial_time_shift_variance is None:
                    print('>> Initial time-shift std set to the empirical std of the visit ages: %.2f'
                          % math.sqrt(var_visit_age))
                    self.initial_time_shift_variance = var_visit_age
                else:
                    # The attribute stores a variance whereas the message reports a std.
                    print(('>> Initial time-shift std set by the user to %.2f ; note that the empirical std of '
                           'the visit ages is %.2f')
                          % (math.sqrt(self.initial_time_shift_variance), math.sqrt(var_visit_age)))

    # Setting the number of threads in general settings
    Settings().number_of_threads = self.number_of_threads
    if self.number_of_threads > 1:
        print(">> I will use", self.number_of_threads,
              "threads, and I set OMP_NUM_THREADS and torch_num_threads to 1.")
        os.environ['OMP_NUM_THREADS'] = "1"
        torch.set_num_threads(1)
    else:
        print('>> Setting OMP_NUM_THREADS and torch_num_threads to 4.')
        os.environ['OMP_NUM_THREADS'] = "4"
        torch.set_num_threads(4)

    # The start method can only be set once per process: a second attempt is reported and ignored.
    try:
        set_start_method("spawn")
    except RuntimeError as error:
        print('>> Warning: ' + str(error) + ' [ in xml_parameters ]. Ignoring.')

    self._initialize_state_file()

    # Freeze the fixed effects in case of a registration.
    if self.model_type == 'Registration'.lower():
        self.freeze_template = True
        self.freeze_control_points = True

    elif self.model_type == 'LongitudinalRegistration'.lower():
        self.freeze_template = True
        self.freeze_control_points = True
        self.freeze_momenta = True
        self.freeze_modulation_matrix = True
        self.freeze_reference_time = True
        self.freeze_time_shift_variance = True
        self.freeze_log_acceleration_variance = True
        self.freeze_noise_variance = True

    # Initialize the number of sources if needed.
    if self.model_type == 'LongitudinalAtlas'.lower() \
            and self.initial_modulation_matrix is None and self.number_of_sources is None:
        self.number_of_sources = 4
        print('>> No initial modulation matrix given, neither a number of sources. '
              'The latter will be ARBITRARILY defaulted to 4.')

    if self.dimension <= 1:
        print("Setting the number of sources to 0 because the dimension is 1.")
        self.number_of_sources = 0

    # Initialize the initial_log_acceleration_variance if needed.
    if (self.model_type == 'LongitudinalAtlas'.lower()
        or self.model_type == 'LongitudinalRegistration'.lower()) \
            and self.initial_log_acceleration_variance is None:
        print('>> The initial log-acceleration std fixed effect is ARBITRARILY set to 0.5')
        log_acceleration_std = 0.5
        self.initial_log_acceleration_variance = (log_acceleration_std ** 2)

    # Image grid downsampling factor.
    if not self.downsampling_factor == 1:
        image_object_keys = [key for key, value in self.template_specifications.items()
                             if value['deformable_object_type'].lower() == 'image']
        # Any count above one is rejected: two images used to fall through to the "none" warning.
        if len(image_object_keys) > 1:
            raise RuntimeError('Only a single image object can be used.')
        elif len(image_object_keys) == 1:
            print('>> Setting the image grid downsampling factor to: %d.' % self.downsampling_factor)
            self.template_specifications[image_object_keys[0]]['downsampling_factor'] = self.downsampling_factor
        else:
            msg = 'The "downsampling_factor" parameter is useful only for image data, ' \
                  'but none is considered here. Ignoring.'
            warnings.warn(msg)
def update(self):
    """
    Refresh the geodesic, then bring the transported modulation matrix up to date.

    When the transport has been flagged as modified, every column of the modulation matrix is
    parallel-transported from scratch along the geodesic, ignoring the tangential components.
    When only an extension was requested, the existing transport is extended instead. The
    trajectories are finally checked to share the same number of time points.
    """

    def _blank_matrices():
        # One zero matrix, shaped like the modulation matrix, per point of the trajectory.
        return [Variable(torch.zeros(self.modulation_matrix_t0.size()).type(Settings().tensor_scalar_type),
                         requires_grad=False)
                for _ in range(len(self.control_points_t))]

    # Update the geodesic.
    self.geodesic.update()

    # Convenient attributes for later use.
    self.times = self.geodesic._get_times()
    self.template_points_t = self.geodesic._get_template_points_trajectory()
    self.control_points_t = self.geodesic._get_control_points_trajectory()

    if self.transport_is_modified:
        # Full recomputation of the transport.
        self.projected_modulation_matrix_t = _blank_matrices()

        for source in range(self.number_of_sources):
            column_t0 = self.modulation_matrix_t0[:, source].contiguous().view(
                self.geodesic.momenta_t0.size())
            transported = self.geodesic.parallel_transport(column_t0, is_orthogonal=True)

            # Write each transported column back, flattened, at its time index.
            for time_index, shift in enumerate(transported):
                self.projected_modulation_matrix_t[time_index][:, source] = shift.view(-1)

        self.transport_is_modified = False
        self.backward_extension = 0
        self.forward_extension = 0

    elif self.backward_extension > 0 or self.forward_extension > 0:
        # Only the missing ends of the transport are computed.
        extended_matrices = _blank_matrices()

        for source in range(self.number_of_sources):
            known_shifts = [matrix[:, source].contiguous().view(self.geodesic.momenta_t0.size())
                            for matrix in self.projected_modulation_matrix_t]
            transported = self.geodesic.extend_parallel_transport(
                known_shifts, self.backward_extension, self.forward_extension, is_orthogonal=True)

            for time_index, shift in enumerate(transported):
                extended_matrices[time_index][:, source] = shift.view(-1)

        self.projected_modulation_matrix_t = extended_matrices
        self.backward_extension = 0
        self.forward_extension = 0

    # Consistency check: all the trajectories must have the same length.
    assert len(self.template_points_t[list(self.template_points_t.keys())[0]]) == len(self.control_points_t) \
        == len(self.times) == len(self.projected_modulation_matrix_t), \
        "That's weird: len(self.template_points_t[list(self.template_points_t.keys())[0]]) = %d, " \
        "len(self.control_points_t) = %d, len(self.times) = %d, len(self.projected_modulation_matrix_t) = %d" % \
        (len(self.template_points_t[list(self.template_points_t.keys())[0]]), len(self.control_points_t),
         len(self.times), len(self.projected_modulation_matrix_t))
def create_regular_grid_of_points(box, spacing):
    """
    Build a regular grid of points inside `box`, returned as a (nb_of_points, dimension) numpy array.

    box: (dimension, 2) array holding the lower and upper bound along each axis.
    spacing: distance between two consecutive grid points along every axis.

    The ambient dimension is read from Settings() and must be 1, 2, 3 or 4; along each axis the
    grid is centered within the box.
    """
    dimension = Settings().dimension

    # One centered 1D sampling per axis.
    axes = []
    for d in range(dimension):
        lower = box[d, 0]
        upper = box[d, 1]
        length = upper - lower
        assert (length > 0)

        offset = 0.5 * (length - spacing * math.floor(length / spacing))
        axes.append(np.arange(lower + offset, upper + 1e-10, spacing))

    if dimension not in (1, 2, 3, 4):
        raise RuntimeError('Invalid ambient space dimension.')

    # Cartesian product of the axes, one flattened coordinate grid per column.
    coordinate_grids = np.meshgrid(*axes)
    control_points = np.zeros((coordinate_grids[0].flatten().shape[0], dimension))
    for column, grid in enumerate(coordinate_grids):
        control_points[:, column] = grid.flatten()

    return control_points
def main():
    """
    Command-line entry point: parse the arguments, configure logging and the output directory,
    read the three xml files and launch the routine matching the requested model type.
    """
    import logging

    log = logging.getLogger(__name__)
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

    # Command line -----------------------------------------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Deformetrica')
    # positional arguments
    parser.add_argument('model', type=str, help='model xml file')
    parser.add_argument('dataset', type=str, help='data-set xml file')
    parser.add_argument('optimization', type=str, help='optimization parameters xml file')
    # optional arguments
    parser.add_argument('-o', '--output', type=str, help='output folder')
    # logging levels: https://docs.python.org/2/library/logging.html#logging-levels
    parser.add_argument('--verbosity', '-v',
                        type=str,
                        default='WARNING',
                        choices=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        help='set output verbosity')
    args = parser.parse_args()

    # Logging ----------------------------------------------------------------------------------------------------------
    try:
        level = logging.getLevelName(args.verbosity)
        logging.basicConfig(level=level, format=log_format)
    except ValueError:
        log.warning('Logging level was not recognized. Using INFO.')
        level = logging.INFO

    log.debug('Using verbosity level: ' + args.verbosity)
    logging.basicConfig(level=level, format=log_format)

    # Basic info printing
    log.info(info())

    # Read xml files, set general settings, and call the adapted function.
    try:
        if args.output is not None:
            log.info('Setting output directory to: ' + args.output)
            Settings().set_output_dir(args.output)
        else:
            log.info('Creating the output directory: ' + Settings().output_dir)
            os.makedirs(Settings().output_dir)
    except FileExistsError:
        # An already existing output directory is simply reused.
        pass

    log.info('[ read_all_xmls function ]')
    xml_parameters = XmlParameters()
    xml_parameters.read_all_xmls(args.model, args.dataset, args.optimization)

    # Dispatch on the (lower-cased) model type.
    model_type = xml_parameters.model_type

    if model_type in ('DeterministicAtlas'.lower(), 'Registration'.lower()):
        estimate_deterministic_atlas(xml_parameters)

    elif model_type == 'BayesianAtlas'.lower():
        estimate_bayesian_atlas(xml_parameters)

    elif model_type == 'Regression'.lower():
        estimate_geodesic_regression(xml_parameters)

    elif model_type == 'LongitudinalAtlas'.lower():
        estimate_longitudinal_atlas(xml_parameters)

    elif model_type == 'LongitudinalRegistration'.lower():
        estimate_longitudinal_registration(xml_parameters)

    elif model_type == 'Shooting'.lower():
        run_shooting(xml_parameters)

    elif model_type == 'ParallelTransport'.lower():
        compute_parallel_transport(xml_parameters)

    elif model_type == 'LongitudinalMetricLearning'.lower():
        estimate_longitudinal_metric_model(xml_parameters)

    elif model_type == 'LongitudinalMetricRegistration'.lower():
        estimate_longitudinal_metric_registration(xml_parameters)

    else:
        raise RuntimeError(
            'Unrecognized model-type: "' + xml_parameters.model_type +
            '". Check the corresponding field in the model.xml input file.')
def set_connectivity(self, connectivity):
    """
    Store `connectivity` (a numpy array) as a tensor of the integer type configured in Settings(),
    and flag the object as modified.
    """
    as_tensor = torch.from_numpy(connectivity)
    integer_type = Settings().tensor_integer_type
    self.connectivity = as_tensor.type(integer_type)
    self.is_modified = True
def test_write_3D_array(self):
    """Writing a 3D array and reading it back must give the same values."""
    source_path = os.path.join(Settings().unit_tests_data_dir, "Momenta.txt")
    momenta = read_3D_array(source_path)

    write_3D_array(momenta, self.test_output_file_path)
    reloaded = read_3D_array(self.test_output_file_path)

    self.assertTrue(np.allclose(momenta, reloaded))
def test_read_3D_array(self):
    """The reference momenta file must be read with the expected shape and a few known rows."""
    source_path = os.path.join(Settings().unit_tests_data_dir, "Momenta.txt")
    momenta = read_3D_array(source_path)

    self.assertEqual(momenta.shape, (4, 72, 3))

    # (index, expected row) pairs, checked in order.
    known_rows = (
        ((0, 0), [-0.0313538, -0.00373486, -0.0256917]),
        ((0, -1), [-0.518624, 1.47211, 0.880905]),
        ((-1, -1), [2.81286, -0.353167, -2.16408]),
    )
    for index, expected in known_rows:
        self.assertTrue(np.allclose(momenta[index], np.array(expected)))
def inverse_metric(self, q):
    """
    Return the identity matrix of size `self.dimension`, wrapped in a Variable and cast to the
    scalar tensor type configured in Settings(). The point `q` is not used.
    """
    identity = torch.eye(self.dimension)
    scalar_type = Settings().tensor_scalar_type
    return Variable(identity.type(scalar_type))