Example #1
0
def run_shooting(xml_parameters, number_of_time_points=10):
    """
    Shoot the template along every set of initial momenta and write the flows.

    The template is built from ``xml_parameters.template_specifications``.
    For each index ``i`` of the momenta array read from
    ``xml_parameters.initial_momenta``, an ``Exponential`` is integrated from the
    control points read from ``xml_parameters.initial_control_points``. Then:

    - the template flow is written with ``write_flow``, each object name
      suffixed by ``_<i>``;
    - the control-point/momenta flow is written with the prefix ``Shooting_<i>``.

    :param xml_parameters: parsed parameters object. Attributes read:
        ``template_specifications``, ``initial_control_points``,
        ``initial_momenta``, ``deformation_kernel_type``,
        ``deformation_kernel_width``, ``use_rk2``.
    :param number_of_time_points: number of time points set on the exponential
        (defaults to 10, the value previously hard-coded).
    :raises ArgumentError: if the control-points path or the momenta path is
        not provided.
    """
    print('[ run_shooting function ]')
    print('')

    # Validate the mandatory inputs first, so that a missing path fails before
    # any template I/O is performed.
    if xml_parameters.initial_control_points is None:
        raise ArgumentError('Please specify a path to control points to perform a shooting')
    if xml_parameters.initial_momenta is None:
        raise ArgumentError('Please specify a path to momenta to perform a shooting')

    # Create the template object.
    t_list, t_name, t_name_extension, _, _ = create_template_metadata(
        xml_parameters.template_specifications)
    print("Object list:", t_list)

    template = DeformableMultiObject()
    template.object_list = t_list
    template.update()

    # Read the control points and momenta.
    control_points = read_2D_array(xml_parameters.initial_control_points)
    momenta = read_3D_array(xml_parameters.initial_momenta)

    # Cast every tensor to the configured scalar type, as _exp_parallelize does,
    # so that the template data, control points and momenta share one dtype.
    tensor_type = Settings().tensor_scalar_type
    template_data_torch = Variable(torch.from_numpy(template.get_points()).type(tensor_type))
    control_points_torch = Variable(torch.from_numpy(control_points).type(tensor_type))
    momenta_torch = Variable(torch.from_numpy(momenta).type(tensor_type))

    exponential = Exponential()
    exponential.set_initial_control_points(control_points_torch)
    exponential.set_initial_template_data(template_data_torch)
    exponential.number_of_time_points = number_of_time_points
    exponential.set_kernel(
        kernel_factory.factory(xml_parameters.deformation_kernel_type,
                               xml_parameters.deformation_kernel_width))
    exponential.set_use_rk2(xml_parameters.use_rk2)

    # One shooting per momenta set (first axis of the momenta array).
    for i, momenta_i in enumerate(momenta_torch):
        exponential.set_initial_momenta(momenta_i)
        exponential.update()
        names = [name + "_" + str(i) for name in t_name]
        exponential.write_flow(names, t_name_extension, template)
        exponential.write_control_points_and_momenta_flow("Shooting_" + str(i))
def _exp_parallelize(control_points, initial_momenta, projected_momenta,
                     xml_parameters):
    """
    Build a geodesic, parallel-transport momenta along it, and shoot the template
    from every geodesic point.

    The geodesic is defined by ``control_points`` and ``initial_momenta`` at
    ``t0``. Along it, ``projected_momenta`` is parallel-transported. At every
    time point of the geodesic, the template is shot with the transported
    momenta through an ``Exponential``.

    Outputs written:

    - the geodesic flow, under the ``"Regression"`` prefix;
    - for each time point ``i`` at time ``t``:
        - ``control_Points_tp_<i>__age_<t>.txt``;
        - ``momenta_tp_<i>__age_<t>.txt``;
        - ``transported_momenta_tp_<i>__age_<t>.txt``;
    - the shot template, one name per object:
      ``<object>_parallel_curve_tp_<i>__age_<t>_<extension>``.

    :param control_points: control points at t0, passed to
        ``Geodesic.set_control_points_t0``.
    :param initial_momenta: momenta at t0 defining the geodesic.
    :param projected_momenta: momenta to parallel-transport along the geodesic.
    :param xml_parameters: parsed parameters. Attributes read:
        ``template_specifications``, ``deformation_kernel_type``,
        ``deformation_kernel_width``, ``use_rk2``,
        ``concentration_of_time_points``, ``number_of_time_points``,
        ``tmin``, ``tmax``, ``t0``.
    :raises ValueError: if ``tmin`` is -inf or ``tmax`` is +inf, i.e. left
        unspecified.
    """
    # Those are mandatory parameters. Raise explicitly rather than assert, so
    # the check also runs under `python -O`, and do it before any I/O.
    if xml_parameters.tmin == -float("inf"):
        raise ValueError("Please specify a minimum time for the geodesic trajectory")
    if xml_parameters.tmax == float("inf"):
        raise ValueError("Please specify a maximum time for the geodesic trajectory")

    objects_list, objects_name, objects_name_extension, _, _ = create_template_metadata(
        xml_parameters.template_specifications)
    template = DeformableMultiObject()
    template.object_list = objects_list
    template.update()

    template_data_torch = Variable(
        torch.from_numpy(template.get_points()).type(Settings().tensor_scalar_type))

    geodesic = Geodesic()
    geodesic.concentration_of_time_points = xml_parameters.concentration_of_time_points
    geodesic.set_kernel(
        kernel_factory.factory(xml_parameters.deformation_kernel_type,
                               xml_parameters.deformation_kernel_width))
    geodesic.set_use_rk2(xml_parameters.use_rk2)

    geodesic.tmin = xml_parameters.tmin
    geodesic.tmax = xml_parameters.tmax
    # Without an explicit t0, the geodesic starts at tmin.
    if xml_parameters.t0 is None:
        geodesic.t0 = geodesic.tmin
    else:
        geodesic.t0 = xml_parameters.t0

    geodesic.set_momenta_t0(initial_momenta)
    geodesic.set_control_points_t0(control_points)
    geodesic.set_template_data_t0(template_data_torch)
    geodesic.update()

    # Write the flow of the geodesic.
    geodesic.write("Regression", objects_name, objects_name_extension, template)

    # Parallel-transport the projected momenta along the geodesic.
    parallel_transport_trajectory = geodesic.parallel_transport(projected_momenta)

    # Trajectory characteristics, one entry per geodesic time point.
    times = geodesic._get_times()
    control_points_traj = geodesic._get_control_points_trajectory()
    momenta_traj = geodesic._get_momenta_trajectory()
    template_data_traj = geodesic._get_template_data_trajectory()

    exponential = Exponential()
    exponential.number_of_time_points = xml_parameters.number_of_time_points
    exponential.set_kernel(
        kernel_factory.factory(xml_parameters.deformation_kernel_type,
                               xml_parameters.deformation_kernel_width))
    exponential.set_use_rk2(xml_parameters.use_rk2)

    # Save the trajectory and the corresponding shape trajectory.
    for i, (t, cp, mom, transported_mom, td) in enumerate(
            zip(times, control_points_traj, momenta_traj,
                parallel_transport_trajectory, template_data_traj)):
        # Shared "_tp_<i>__age_<t>" part of every file name for this time point.
        suffix = "_tp_" + str(i) + "__age_" + str(t)

        # Write the control points and momenta.
        write_2D_array(cp.data.numpy(), "control_Points" + suffix + ".txt")
        write_3D_array(mom.data.numpy(), "momenta" + suffix + ".txt")
        write_3D_array(transported_mom.data.numpy(),
                       "transported_momenta" + suffix + ".txt")

        # Shoot from the current geodesic point with the transported momenta.
        exponential.set_initial_template_data(td)
        exponential.set_initial_control_points(cp)
        exponential.set_initial_momenta(transported_mom)
        exponential.update()

        # Reuse the template object to write the shot shapes.
        template.set_points(exponential.get_template_data())
        names = [name + "_parallel_curve" + suffix + "_" + extension
                 for name, extension in zip(objects_name, objects_name_extension)]
        template.write(names)