예제 #1
0
def run_shooting(xml_parameters):
    """
    Shoot the template once per momenta entry and write each resulting flow.

    The template is built from ``xml_parameters.template_specifications``.
    Control points and momenta are read from the paths held in
    ``xml_parameters.initial_control_points`` and
    ``xml_parameters.initial_momenta``. For every index ``i`` along the first
    axis of the momenta, an Exponential is updated and then asked to
    ``write_flow`` and ``write_control_points_and_momenta_flow``.

    :param xml_parameters: object providing ``template_specifications``,
        ``initial_control_points``, ``initial_momenta``,
        ``deformation_kernel_type``, ``deformation_kernel_width`` and ``use_rk2``.
    :raises ArgumentError: if the control points path or the momenta path is None.
    """
    print('[ run_shooting function ]')
    print('')

    # Create the template object.
    metadata = create_template_metadata(xml_parameters.template_specifications)
    object_list, object_names, object_extensions, _, _ = metadata

    print("Object list:", object_list)

    template = DeformableMultiObject()
    template.object_list = object_list
    template.update()

    # Read the control points first, then the momenta; both paths are mandatory.
    if xml_parameters.initial_control_points is None:
        raise ArgumentError('Please specify a path to control points to perform a shooting')
    control_points = read_2D_array(xml_parameters.initial_control_points)

    if xml_parameters.initial_momenta is None:
        raise ArgumentError('Please specify a path to momenta to perform a shooting')
    momenta = read_3D_array(xml_parameters.initial_momenta)

    # Wrap the numpy arrays as torch Variables.
    template_data_torch = Variable(torch.from_numpy(template.get_points()))
    momenta_torch = Variable(torch.from_numpy(momenta))
    control_points_torch = Variable(torch.from_numpy(control_points))

    # Configure the exponential; the number of time points is fixed to 10 here.
    shooting = Exponential()
    shooting.set_initial_control_points(control_points_torch)
    shooting.set_initial_template_data(template_data_torch)
    shooting.number_of_time_points = 10
    shooting.kernel = kernel_factory.factory(
        xml_parameters.deformation_kernel_type,
        xml_parameters.deformation_kernel_width)
    shooting.set_use_rk2(xml_parameters.use_rk2)

    for index in range(len(momenta_torch)):
        suffix = str(index)
        shooting.set_initial_momenta(momenta_torch[index])
        shooting.update()
        # NOTE(review): the returned template data is not used below; the call
        # is kept in case the getter has side effects -- confirm, then drop.
        _deformed_points = shooting.get_template_data()
        flow_names = [name + "_" + suffix for name in object_names]
        shooting.write_flow(flow_names, object_extensions, template)
        shooting.write_control_points_and_momenta_flow("Shooting_" + suffix)
예제 #2
0
    def initialize_template_attributes(self, template_specifications):
        """
        Populate the template-related attributes from the template specifications.

        The five values returned by create_template_metadata are stored, in
        order, in self.template.object_list, self.objects_name,
        self.objects_name_extension, self.objects_noise_variance and
        self.multi_object_attachment.
        """
        # The right-hand side is fully evaluated and unpacked before any
        # attribute is assigned; targets are then assigned left to right.
        (self.template.object_list,
         self.objects_name,
         self.objects_name_extension,
         self.objects_noise_variance,
         self.multi_object_attachment) = create_template_metadata(template_specifications)
    def __init__(self,
                 template_specifications,
                 dimension=default.dimension,
                 tensor_scalar_type=default.tensor_scalar_type,
                 tensor_integer_type=default.tensor_integer_type,
                 number_of_threads=default.number_of_threads,
                 deformation_kernel_type=default.deformation_kernel_type,
                 deformation_kernel_width=default.deformation_kernel_width,
                 deformation_kernel_device=default.deformation_kernel_device,
                 shoot_kernel_type=default.shoot_kernel_type,
                 number_of_time_points=default.number_of_time_points,
                 freeze_template=default.freeze_template,
                 use_sobolev_gradient=default.use_sobolev_gradient,
                 smoothing_kernel_width=default.smoothing_kernel_width,
                 estimate_initial_velocity=default.estimate_initial_velocity,
                 initial_velocity_weight=default.initial_velocity_weight,
                 initial_control_points=default.initial_control_points,
                 freeze_control_points=default.freeze_control_points,
                 initial_cp_spacing=default.initial_cp_spacing,
                 initial_impulse_t=None,
                 initial_velocity=None,
                 **kwargs):
        """
        Set up the 'AccelerationRegression' model and its fixed effects.

        Builds the acceleration path and the template, then fills
        self.fixed_effects with 'template_data', 'control_points', the images
        'A', 'B' and 'C', 'impulse_t' and, when estimate_initial_velocity is
        truthy, 'initial_velocity'.

        :param template_specifications: passed to create_template_metadata
            together with the dimension.
        :param initial_impulse_t: forwarded to initialize_impulse.
        :param initial_velocity: forwarded to initialize_initial_velocity.
        :param kwargs: accepted but not read in this method.

        The remaining keyword arguments are stored on the instance or forwarded
        to the kernel factory, AccelerationPath and the initialize_* helpers.
        """
        AbstractStatisticalModel.__init__(self, name='AccelerationRegression')

        # Global-like attributes.
        self.dimension = dimension
        self.tensor_scalar_type = tensor_scalar_type
        self.tensor_integer_type = tensor_integer_type
        self.number_of_threads = number_of_threads

        # Declare the model structure; both entries are filled in further down.
        self.fixed_effects['template_data'] = None
        self.fixed_effects['control_points'] = None

        self.freeze_template = freeze_template
        self.freeze_control_points = freeze_control_points

        # Deformation.
        deformation_kernel = kernel_factory.factory(
            deformation_kernel_type,
            deformation_kernel_width,
            device=deformation_kernel_device)
        self.acceleration_path = AccelerationPath(
            kernel=deformation_kernel,
            shoot_kernel_type=shoot_kernel_type,
            number_of_time_points=number_of_time_points)

        # Template.
        (template_objects,
         self.objects_name,
         self.objects_name_extension,
         self.objects_noise_variance,
         self.multi_object_attachment) = create_template_metadata(
             template_specifications, self.dimension)

        self.template = DeformableMultiObject(template_objects)
        self.template.update()

        # Gompertz images A, B and C: zero arrays with the shape of the
        # template's 'image_intensities' entry.
        intensities = self.template.get_data()['image_intensities']
        for image_key in ('A', 'B', 'C'):
            self.fixed_effects[image_key] = np.zeros(intensities.shape)

        self.number_of_objects = len(self.template.object_list)

        # Optional Sobolev smoothing kernel, same kernel type as the deformation.
        self.use_sobolev_gradient = use_sobolev_gradient
        self.smoothing_kernel_width = smoothing_kernel_width
        if self.use_sobolev_gradient:
            self.sobolev_kernel = kernel_factory.factory(
                deformation_kernel_type,
                smoothing_kernel_width,
                device=deformation_kernel_device)

        # Template data.
        self.fixed_effects['template_data'] = self.template.get_data()

        # Control points.
        self.fixed_effects['control_points'] = initialize_control_points(
            initial_control_points,
            self.template,
            initial_cp_spacing,
            deformation_kernel_width,
            self.dimension,
            False)

        self.estimate_initial_velocity = estimate_initial_velocity
        self.initial_velocity_weight = initial_velocity_weight

        self.number_of_control_points = len(self.fixed_effects['control_points'])
        self.number_of_time_points = number_of_time_points

        # Impulse, and the initial velocity only when it is to be estimated.
        self.fixed_effects['impulse_t'] = initialize_impulse(
            initial_impulse_t,
            self.number_of_time_points,
            self.number_of_control_points,
            self.dimension)
        if self.estimate_initial_velocity:
            self.fixed_effects['initial_velocity'] = initialize_initial_velocity(
                initial_velocity,
                self.number_of_control_points,
                self.dimension)
        logger.info('>> Creating the output directory: "' + output_dir + '"')
        os.makedirs(output_dir)

    xml_parameters = XmlParameters()
    xml_parameters._read_model_xml(model_xml_path)

    deformetrica = Deformetrica(output_dir=output_dir)
    template_specifications, model_options, _ = deformetrica.further_initialization(
        'Shooting', xml_parameters.template_specifications, get_model_options(xml_parameters))

    """
    Load the template, control points, momenta, modulation matrix.
    """

    # Template.
    t_list, objects_name, objects_name_extension, _, _ = create_template_metadata(template_specifications)

    template = DeformableMultiObject(t_list)
    template_data = {key: torch.from_numpy(value).type(model_options['tensor_scalar_type'])
                     for key, value in template.get_data().items()}
    template_points = {key: torch.from_numpy(value).type(model_options['tensor_scalar_type'])
                       for key, value in template.get_points().items()}

    # Control points.
    control_points = read_2D_array(model_options['initial_control_points'])
    logger.info('>> Reading ' + str(len(control_points)) + ' initial control points from file: '
          + model_options['initial_control_points'])
    control_points = torch.from_numpy(control_points).type(model_options['tensor_scalar_type'])

    # Momenta.
    momenta = read_3D_array(model_options['initial_momenta'])
def _exp_parallelize(control_points, initial_momenta, projected_momenta,
                     xml_parameters):
    """
    Write a geodesic, transport momenta along it and shoot from every time point.

    A Geodesic is configured from ``xml_parameters`` and the given control
    points and initial momenta, updated, and written under the name
    "Regression". ``projected_momenta`` is then passed to
    ``geodesic.parallel_transport``. For each time point of the trajectory the
    control points, momenta and transported momenta are written to text files,
    an Exponential is shot with the transported momenta, and the template is
    written with the resulting points.

    :param control_points: passed to ``geodesic.set_control_points_t0``.
    :param initial_momenta: passed to ``geodesic.set_momenta_t0``.
    :param projected_momenta: passed to ``geodesic.parallel_transport``.
    :param xml_parameters: object providing ``template_specifications``,
        ``concentration_of_time_points``, ``deformation_kernel_type``,
        ``deformation_kernel_width``, ``use_rk2``, ``tmin``, ``tmax``, ``t0``
        and ``number_of_time_points``.
    :raises AssertionError: if ``tmin`` is -inf or ``tmax`` is +inf.
    """
    metadata = create_template_metadata(xml_parameters.template_specifications)
    objects_list, objects_name, objects_name_extension, _, _ = metadata

    template = DeformableMultiObject()
    template.object_list = objects_list
    template.update()

    template_points = torch.from_numpy(template.get_points())
    template_data_torch = Variable(
        template_points.type(Settings().tensor_scalar_type))

    geodesic = Geodesic()
    geodesic.concentration_of_time_points = xml_parameters.concentration_of_time_points
    geodesic_kernel = kernel_factory.factory(
        xml_parameters.deformation_kernel_type,
        xml_parameters.deformation_kernel_width)
    geodesic.set_kernel(geodesic_kernel)
    geodesic.set_use_rk2(xml_parameters.use_rk2)

    # Those are mandatory parameters.
    assert xml_parameters.tmin != -float(
        "inf"), "Please specify a minimum time for the geodesic trajectory"
    assert xml_parameters.tmax != float(
        "inf"), "Please specify a maximum time for the geodesic trajectory"

    geodesic.tmin = xml_parameters.tmin
    geodesic.tmax = xml_parameters.tmax
    # t0 falls back to tmin when it is not specified.
    geodesic.t0 = geodesic.tmin if xml_parameters.t0 is None else xml_parameters.t0

    geodesic.set_momenta_t0(initial_momenta)
    geodesic.set_control_points_t0(control_points)
    geodesic.set_template_data_t0(template_data_torch)
    geodesic.update()

    # Write the flow of the geodesic.
    geodesic.write("Regression", objects_name, objects_name_extension,
                   template)

    # Transport the projected momenta along the geodesic.
    parallel_transport_trajectory = geodesic.parallel_transport(
        projected_momenta)

    # Trajectory characteristics.
    times = geodesic._get_times()
    control_points_traj = geodesic._get_control_points_trajectory()
    momenta_traj = geodesic._get_momenta_trajectory()
    template_data_traj = geodesic._get_template_data_trajectory()

    exponential = Exponential()
    exponential.number_of_time_points = xml_parameters.number_of_time_points
    exponential_kernel = kernel_factory.factory(
        xml_parameters.deformation_kernel_type,
        xml_parameters.deformation_kernel_width)
    exponential.set_kernel(exponential_kernel)
    exponential.set_use_rk2(xml_parameters.use_rk2)

    # Save this trajectory and the corresponding shape trajectory.
    trajectory = zip(times, control_points_traj, momenta_traj,
                     parallel_transport_trajectory, template_data_traj)
    for i, (age, cp, mom, transported_mom, td) in enumerate(trajectory):
        # Shared part of every file name written for this time point.
        stamp = "_tp_" + str(i) + "__age_" + str(age)

        # Control points, momenta and transported momenta.
        write_2D_array(cp.data.numpy(), "control_Points" + stamp + ".txt")
        write_3D_array(mom.data.numpy(), "momenta" + stamp + ".txt")
        write_3D_array(transported_mom.data.numpy(),
                       "transported_momenta" + stamp + ".txt")

        # Shooting from the geodesic.
        exponential.set_initial_template_data(td)
        exponential.set_initial_control_points(cp)
        exponential.set_initial_momenta(transported_mom)
        exponential.update()

        parallel_td = exponential.get_template_data()
        template.set_points(parallel_td)
        # Extensions are indexed by position so that a shorter extension list
        # raises IndexError instead of being silently truncated.
        names = [
            object_name + "_parallel_curve" + stamp + "_" + objects_name_extension[k]
            for k, object_name in enumerate(objects_name)
        ]
        template.write(names)