Ejemplo n.º 1
0
 def setUp(self):
     """Build a CUDA Torch kernel instance before each test runs."""
     self.test_on_device = 'cuda:0'
     kernel = kernel_factory.factory(kernel_factory.Type.TorchCudaKernel,
                                     kernel_width=1.,
                                     device=self.test_on_device)
     self.kernel_instance = kernel
     super().setUp()
Ejemplo n.º 2
0
def create_template_metadata(template_specifications):
    """
    Creates a longitudinal dataset object from xml parameters.

    Parameters
    ----------
    template_specifications : dict
        Maps each object id to a specification dict with keys such as
        'filename', 'deformable_object_type', 'noise_std', and — for
        current/varifold norms — 'kernel_type' and 'kernel_width'.

    Returns
    -------
    tuple
        (objects_list, objects_name, objects_name_extension,
         objects_noise_variance, multi_object_attachment)
    """

    objects_list = []
    objects_name = []
    objects_noise_variance = []
    objects_name_extension = []
    objects_norm = []
    objects_norm_kernel_type = []
    objects_norm_kernel_width = []

    # NOTE: 'specs' used instead of the original loop name 'object',
    # which shadowed the builtin.
    for object_id, specs in template_specifications.items():
        filename = specs['filename']
        object_type = specs['deformable_object_type'].lower()

        # Membership test against the pre-lowered type names.
        assert object_type in [
            'surfacemesh', 'polyline', 'pointcloud', 'landmark', 'image'
        ], "Unknown object type."

        _, extension = splitext(filename)
        reader = DeformableObjectReader()

        objects_list.append(reader.create_object(filename, object_type))
        objects_name.append(object_id)
        objects_name_extension.append(extension)

        # A negative noise_std is a sentinel meaning "variance unknown".
        if specs['noise_std'] < 0:
            objects_noise_variance.append(-1.0)
        else:
            objects_noise_variance.append(specs['noise_std'] ** 2)

        object_norm = _get_norm_for_object(specs, object_id)

        objects_norm.append(object_norm)

        if object_norm in ['current', 'varifold']:
            objects_norm_kernel_type.append(specs['kernel_type'])
            objects_norm_kernel_width.append(float(specs['kernel_width']))

        else:
            objects_norm_kernel_type.append("no_kernel")
            objects_norm_kernel_width.append(0.)

        # Optional grid downsampling parameter for image data.
        if object_type == 'image' and 'downsampling_factor' in specs:
            objects_list[-1].downsampling_factor = specs['downsampling_factor']

    multi_object_attachment = MultiObjectAttachment()
    multi_object_attachment.attachment_types = objects_norm
    # Build one kernel per object norm, pairing type and width.
    for kernel_type, kernel_width in zip(objects_norm_kernel_type,
                                         objects_norm_kernel_width):
        multi_object_attachment.kernels.append(
            kernel_factory.factory(kernel_type, kernel_width))

    return objects_list, objects_name, objects_name_extension, objects_noise_variance, multi_object_attachment
Ejemplo n.º 3
0
def create_template_metadata(template_specifications, dimension=None):
    """
    Creates a longitudinal dataset object from xml parameters.

    Parameters
    ----------
    template_specifications : dict
        Maps each object id to a specification dict ('filename',
        'deformable_object_type', norm and kernel entries, ...).
    dimension : int, optional
        Ambient dimension forwarded to the object reader.

    Returns
    -------
    tuple
        (objects_list, objects_name, objects_name_extension,
         multi_object_attachment)
    """

    objects_list = []
    objects_name = []
    objects_name_extension = []
    objects_norm = []
    objects_norm_kernels = []

    # NOTE: 'specs' used instead of the original loop name 'object',
    # which shadowed the builtin.
    for object_id, specs in template_specifications.items():
        filename = specs['filename']
        object_type = specs['deformable_object_type'].lower()

        # Membership test against the pre-lowered type names.
        assert object_type in ['surfacemesh', 'polyline', 'pointcloud',
                               'landmark', 'image'], "Unknown object type."

        _, extension = splitext(filename)
        reader = DeformableObjectReader()

        objects_list.append(reader.create_object(filename, object_type, dimension=dimension))
        objects_name.append(object_id)
        objects_name_extension.append(extension)

        object_norm = _get_norm_for_object(specs, object_id)

        objects_norm.append(object_norm)

        if object_norm in ['current', 'pointcloud', 'varifold']:
            # Kernel-based norms need a torch kernel of the requested width.
            objects_norm_kernels.append(kernel_factory.factory(
                'torch',
                specs['kernel_width'],
                device=specs.get('kernel_device', default.deformation_kernel_device)))
        else:
            objects_norm_kernels.append(kernel_factory.factory(kernel_factory.Type.NO_KERNEL))

        # Optional grid downsampling parameter for image data.
        if object_type == 'image' and 'downsampling_factor' in specs:
            objects_list[-1].downsampling_factor = specs['downsampling_factor']

    multi_object_attachment = MultiObjectAttachment(objects_norm, objects_norm_kernels)

    return objects_list, objects_name, objects_name_extension, multi_object_attachment
Ejemplo n.º 4
0
def run_shooting(xml_parameters):
    """
    Shoot the template along each initial momenta vector read from the xml
    parameters, writing the resulting flows and control-point trajectories.

    Raises
    ------
    RuntimeError
        If the initial control points or momenta paths are missing.
        (The original raised `ArgumentError`, which is not defined here and
        whose argparse namesake takes two arguments — a latent crash.)
    """

    print('[ run_shooting function ]')
    print('')

    # Create the template object.
    t_list, t_name, t_name_extension, t_noise_variance, multi_object_attachment = \
        create_template_metadata(xml_parameters.template_specifications)

    print("Object list:", t_list)

    template = DeformableMultiObject()
    template.object_list = t_list
    template.update()

    # Read control points and momenta; both are mandatory for a shooting.
    if xml_parameters.initial_control_points is not None:
        control_points = read_2D_array(xml_parameters.initial_control_points)
    else:
        raise RuntimeError('Please specify a path to control points to perform a shooting')

    if xml_parameters.initial_momenta is not None:
        momenta = read_3D_array(xml_parameters.initial_momenta)
    else:
        raise RuntimeError('Please specify a path to momenta to perform a shooting')

    template_data_numpy = template.get_points()
    template_data_torch = Variable(torch.from_numpy(template_data_numpy))

    momenta_torch = Variable(torch.from_numpy(momenta))
    control_points_torch = Variable(torch.from_numpy(control_points))

    exp = Exponential()
    exp.set_initial_control_points(control_points_torch)
    exp.set_initial_template_data(template_data_torch)
    exp.number_of_time_points = 10
    exp.kernel = kernel_factory.factory(xml_parameters.deformation_kernel_type,
                                        xml_parameters.deformation_kernel_width)
    exp.set_use_rk2(xml_parameters.use_rk2)

    # One shooting (and one set of output files) per momenta vector.
    for i in range(len(momenta_torch)):
        exp.set_initial_momenta(momenta_torch[i])
        exp.update()
        names = [elt + "_" + str(i) for elt in t_name]
        exp.write_flow(names, t_name_extension, template)
        exp.write_control_points_and_momenta_flow("Shooting_" + str(i))
Ejemplo n.º 5
0
    def test_parallel_transport(self):
        """
        Check that parallel transport on a chosen example converges towards
        the ground truth (obtained from the C++ deformetrica implementation).
        """
        data_dir = os.path.join(Settings().unit_tests_data_dir,
                                "parallel_transport")
        control_points = read_2D_array(
            os.path.join(data_dir, "control_points.txt"))
        momenta = read_3D_array(
            os.path.join(data_dir, "geodesic_momenta.txt"))
        momenta_to_transport = read_3D_array(
            os.path.join(data_dir, "momenta_to_transport.txt"))
        transported_momenta_truth = read_3D_array(
            os.path.join(data_dir, "ground_truth_transport.txt"))

        # control_points = np.array([[0.1, 2., 0.2]])
        # momenta = np.array([[1., 0., 0.]])
        # momenta_to_transport = np.array([[0.2, 0.3, 0.4]])

        scalar_type = Settings().tensor_scalar_type
        control_points_torch = Variable(
            torch.from_numpy(control_points).type(scalar_type))
        momenta_torch = Variable(
            torch.from_numpy(momenta).type(scalar_type))
        momenta_to_transport_torch = Variable(
            torch.from_numpy(momenta_to_transport).type(scalar_type))

        # Build and integrate the reference geodesic.
        geodesic = Geodesic()
        geodesic.set_kernel(kernel_factory.factory('torch', 0.01))
        geodesic.set_use_rk2(True)
        geodesic.concentration_of_time_points = 10
        geodesic.tmin = 0.
        geodesic.tmax = 9.
        geodesic.t0 = 0.
        geodesic.set_momenta_t0(momenta_torch)
        geodesic.set_control_points_t0(control_points_torch)
        geodesic.update()

        # Now we transport!
        transported_momenta = geodesic.parallel_transport(
            momenta_to_transport_torch,
            is_orthogonal=False)[-1].detach().numpy()
        self.assertTrue(
            np.allclose(transported_momenta, transported_momenta_truth,
                        rtol=1e-4, atol=1e-1))
Ejemplo n.º 6
0
    def __init__(self, kernel, tensor_size, tensor_initial_device='cpu'):
        """Build a kernel instance plus random input tensors, then warm up once."""
        self.kernel_instance = kernel_factory.factory(kernel, kernel_width=1.)

        device = torch.device(tensor_initial_device)
        self.x = torch.rand(tensor_size, device=device)
        self.y = self.x.clone()
        self.p = torch.ones(tensor_size, device=device)

        # run once for warm-up: cuda pre-compile
        self.res = self.kernel_instance.convolve(self.x, self.y, self.p)
    def get_template_points_exponential(self, time, sources):
        """
        Build an Exponential whose initial conditions are interpolated, at the
        given time, along the stored geodesic trajectories, with initial momenta
        given by the projected modulation matrix applied to the sources.

        Parameters
        ----------
        time : scalar
            Absolute time at which the trajectories are interpolated.
        sources : torch tensor
            Source vector multiplied against the modulation matrix.
            # assumes a 1-D tensor (it is unsqueezed to a column) — TODO confirm

        Returns
        -------
        Exponential
            Ready to be updated; initial template points, control points and
            momenta are set, but update() has not been called.
        """

        # Assert for coherent length of attribute lists.
        assert len(self.template_points_t[list(self.template_points_t.keys())[0]]) == len(self.control_points_t) \
               == len(self.projected_modulation_matrix_t) == len(self.times)

        # Initialize the returned exponential.
        # A fresh kernel is created with the same type/width as the reference one.
        exponential = Exponential()
        exponential.kernel = kernel_factory.factory(
            self.exponential.kernel.kernel_type,
            self.exponential.kernel.kernel_width)
        exponential.number_of_time_points = self.exponential.number_of_time_points
        exponential.use_rk2 = self.exponential.use_rk2

        # Deal with the special case of a geodesic reduced to a single point.
        if len(self.times) == 1:
            print(
                '>> The spatiotemporal reference frame geodesic seems to be reduced to a single point.'
            )
            # No interpolation possible: take the single stored sample directly.
            exponential.set_initial_template_points({
                key: value[0]
                for key, value in self.template_points_t.items()
            })
            exponential.set_initial_control_points(self.control_points_t[0])
            exponential.set_initial_momenta(
                torch.mm(self.projected_modulation_matrix_t[0],
                         sources.unsqueeze(1)).view(
                             self.geodesic.momenta_t0.size()))
            return exponential

        # Standard case.
        # Linear interpolation between samples index-1 and index, with weights
        # summing to one, for template points, control points, and modulation.
        index, weight_left, weight_right = self._get_interpolation_index_and_weights(
            time)
        template_points = {
            key: weight_left * value[index - 1] + weight_right * value[index]
            for key, value in self.template_points_t.items()
        }
        control_points = weight_left * self.control_points_t[
            index - 1] + weight_right * self.control_points_t[index]
        modulation_matrix = weight_left * self.projected_modulation_matrix_t[index - 1] \
                            + weight_right * self.projected_modulation_matrix_t[index]
        # Space shift: modulation matrix applied to the sources, reshaped to
        # match the momenta layout.
        space_shift = torch.mm(modulation_matrix, sources.unsqueeze(1)).view(
            self.geodesic.momenta_t0.size())

        exponential.set_initial_template_points(template_points)
        exponential.set_initial_control_points(control_points)
        exponential.set_initial_momenta(space_shift)
        return exponential
    def __init__(self, kernel_type, kernel_device='CPU', use_cuda=False, data_type='landmark', data_size='small'):
        """
        Build a benchmark fixture: an Exponential with the requested kernel,
        a template (surface mesh or image), control points and random momenta.

        Parameters
        ----------
        kernel_type : kernel type accepted by kernel_factory.factory.
        kernel_device : str, device passed to the kernel factory.
        use_cuda : bool, selects torch.cuda.FloatTensor vs torch.FloatTensor.
        data_type : str, 'landmark' or 'image'; anything else raises.
        data_size : str or int-like; 'small'/'large' load meshes from disk,
            otherwise it is interpreted as an integer size for synthetic data.

        Raises
        ------
        RuntimeError
            If data_type is neither 'landmark' nor 'image'.
        """

        # Fixed seed so every benchmark run sees the same random data.
        np.random.seed(42)
        kernel_width = 10.

        if use_cuda:
            self.tensor_scalar_type = torch.cuda.FloatTensor
        else:
            self.tensor_scalar_type = torch.FloatTensor

        self.exponential = Exponential(kernel=kernel_factory.factory(kernel_type, kernel_width, kernel_device),
                                       number_of_time_points=11, use_rk2_for_flow=False, use_rk2_for_shoot=False)

        if data_type.lower() == 'landmark':
            reader = DeformableObjectReader()
            if data_size == 'small':
                surface_mesh = reader.create_object(path_to_small_surface_mesh_1, 'SurfaceMesh')
                self.control_points = create_regular_grid_of_points(surface_mesh.bounding_box, kernel_width, surface_mesh.dimension)
            elif data_size == 'large':
                surface_mesh = reader.create_object(path_to_large_surface_mesh_1, 'SurfaceMesh')
                self.control_points = create_regular_grid_of_points(surface_mesh.bounding_box, kernel_width, surface_mesh.dimension)
            else:
                # Synthetic mesh: triangles drawn from combinations of 100 points.
                connectivity = np.array(list(itertools.combinations(range(100), 3))[:int(data_size)])  # up to ~16k.
                surface_mesh = SurfaceMesh(3)
                surface_mesh.set_points(np.random.randn(np.max(connectivity) + 1, surface_mesh.dimension))
                surface_mesh.set_connectivity(connectivity)
                surface_mesh.update()
                self.control_points = np.random.randn(int(data_size) // 10, 3)
            # self.template.object_list.append(surface_mesh)
            self.template = DeformableMultiObject([surface_mesh])

        elif data_type.lower() == 'image':
            # Synthetic random cubic image of side int(data_size).
            image = Image(3)
            image.set_intensities(np.random.randn(int(data_size), int(data_size), int(data_size)))
            image.set_affine(np.eye(4))
            image.downsampling_factor = 5.
            image.update()
            self.control_points = create_regular_grid_of_points(image.bounding_box, kernel_width, image.dimension)
            self.control_points = remove_useless_control_points(self.control_points, image, kernel_width)
            # self.template.object_list.append(image)
            self.template = DeformableMultiObject([image])

        else:
            raise RuntimeError('Unknown data_type argument. Choose between "landmark" or "image".')

        # self.template.update()
        # One random momentum vector per control point.
        self.momenta = np.random.randn(*self.control_points.shape)
Ejemplo n.º 9
0
def compute_sobolev_gradient(template_gradient,
                             smoothing_kernel_width,
                             template,
                             square_root=False):
    """
    Smoothing of the template gradient (for landmarks).
    Fully torch input / outputs.
    """
    tensor_type = Settings().tensor_scalar_type
    template_sobolev_gradient = torch.zeros(
        template_gradient.size()).type(tensor_type)

    kernel = kernel_factory.factory(kernel_factory.Type.TorchKernel)
    kernel.kernel_width = smoothing_kernel_width

    cursor = 0
    for template_object in template.object_list:
        # TODO : assert if obj is image or not.
        object_data = torch.from_numpy(
            template_object.get_points()).type(tensor_type)
        end = cursor + len(object_data)
        gradient_slice = template_gradient[cursor:end]

        if square_root:
            # Multiply the gradient by the matrix square root of the kernel matrix.
            kernel_matrix = kernel.get_kernel_matrix(object_data).data.numpy()
            kernel_matrix_sqrt = Variable(
                torch.from_numpy(
                    scipy.linalg.sqrtm(kernel_matrix).real).type(tensor_type),
                requires_grad=False)
            template_sobolev_gradient[cursor:end] = torch.mm(
                kernel_matrix_sqrt, gradient_slice)
        else:
            # Convolve the gradient with the smoothing kernel.
            template_sobolev_gradient[cursor:end] = kernel.convolve(
                object_data, object_data, gradient_slice)

        cursor = end

    return template_sobolev_gradient
def compute_distance_squared(path_to_mesh_1,
                             path_to_mesh_2,
                             deformable_object_type,
                             attachment_type,
                             kernel_width=None):
    """Return the attachment distance between two mesh files as a numpy array."""
    object_type = deformable_object_type.lower()
    reader = DeformableObjectReader()
    multi_object_1 = DeformableMultiObject(
        [reader.create_object(path_to_mesh_1, object_type)])
    multi_object_2 = DeformableMultiObject(
        [reader.create_object(path_to_mesh_2, object_type)])

    multi_object_attachment = MultiObjectAttachment(
        [attachment_type], [kernel_factory.factory('torch', kernel_width)])

    points = {key: torch.from_numpy(value)
              for key, value in multi_object_1.get_points().items()}
    return multi_object_attachment.compute_distances(
        points, multi_object_1, multi_object_2).data.cpu().numpy()
    def __init__(self, kernel=default.deformation_kernel, shoot_kernel_type=None, number_of_time_points=None,
                 initial_control_points=None, control_points_t=None, impulse_t=None, initial_velocity=None,
                 initial_template_points=None, template_points_t=None):
        """
        Store the deformation state: kernels, control points, impulse and
        template-point trajectories. All trajectory arguments may be None and
        filled in later.
        """

        self.kernel = kernel

        if shoot_kernel_type is not None:
            # Dedicated shooting kernel, re-using the flow kernel's width and device.
            self.shoot_kernel = kernel_factory.factory(shoot_kernel_type, kernel_width=kernel.kernel_width, device=kernel.device)
        else:
            self.shoot_kernel = self.kernel

        self.number_of_time_points = number_of_time_points
        # Initial position of control points
        self.initial_control_points = initial_control_points
        # Control points trajectory
        self.control_points_t = control_points_t
        # Momenta trajectory
        self.impulse_t = impulse_t
        self.initial_velocity = initial_velocity

        # Initial template points
        self.initial_template_points = initial_template_points
        # Trajectory of the whole vertices of landmark type at different time steps.
        self.template_points_t = template_points_t
    def __init__(self,
                 template_specifications,
                 dimension=default.dimension,
                 tensor_scalar_type=default.tensor_scalar_type,
                 tensor_integer_type=default.tensor_integer_type,
                 number_of_threads=default.number_of_threads,
                 deformation_kernel_type=default.deformation_kernel_type,
                 deformation_kernel_width=default.deformation_kernel_width,
                 deformation_kernel_device=default.deformation_kernel_device,
                 shoot_kernel_type=default.shoot_kernel_type,
                 number_of_time_points=default.number_of_time_points,
                 freeze_template=default.freeze_template,
                 use_sobolev_gradient=default.use_sobolev_gradient,
                 smoothing_kernel_width=default.smoothing_kernel_width,
                 estimate_initial_velocity=default.estimate_initial_velocity,
                 initial_velocity_weight=default.initial_velocity_weight,
                 initial_control_points=default.initial_control_points,
                 freeze_control_points=default.freeze_control_points,
                 initial_cp_spacing=default.initial_cp_spacing,
                 initial_impulse_t=None,
                 initial_velocity=None,
                 **kwargs):
        """
        Build the 'AccelerationRegression' statistical model: template,
        control points, per-time-point impulses, and (optionally) an initial
        velocity, together with the acceleration path deformation object.
        """

        AbstractStatisticalModel.__init__(self, name='AccelerationRegression')

        # Global-like attributes.
        self.dimension = dimension
        self.tensor_scalar_type = tensor_scalar_type
        self.tensor_integer_type = tensor_integer_type
        self.number_of_threads = number_of_threads

        # Declare model structure.
        self.fixed_effects['template_data'] = None
        self.fixed_effects['control_points'] = None

        self.freeze_template = freeze_template
        self.freeze_control_points = freeze_control_points

        # Deformation.
        self.acceleration_path = AccelerationPath(
            kernel=kernel_factory.factory(deformation_kernel_type,
                                          deformation_kernel_width,
                                          device=deformation_kernel_device),
            shoot_kernel_type=shoot_kernel_type,
            number_of_time_points=number_of_time_points)

        # Template.
        (object_list, self.objects_name, self.objects_name_extension,
         self.objects_noise_variance,
         self.multi_object_attachment) = create_template_metadata(
             template_specifications, self.dimension)

        self.template = DeformableMultiObject(object_list)
        self.template.update()

        template_data = self.template.get_data()

        # Set up the gompertz images A, B, and C
        # (zero-initialized, same shape as the template intensities).
        intensities = template_data['image_intensities']
        self.fixed_effects['A'] = np.zeros(intensities.shape)
        self.fixed_effects['B'] = np.zeros(intensities.shape)
        self.fixed_effects['C'] = np.zeros(intensities.shape)

        self.number_of_objects = len(self.template.object_list)

        self.use_sobolev_gradient = use_sobolev_gradient
        self.smoothing_kernel_width = smoothing_kernel_width
        if self.use_sobolev_gradient:
            # Separate kernel used only to smooth the template gradient.
            self.sobolev_kernel = kernel_factory.factory(
                deformation_kernel_type,
                smoothing_kernel_width,
                device=deformation_kernel_device)

        # Template data.
        self.fixed_effects['template_data'] = self.template.get_data()

        # Control points.
        self.fixed_effects['control_points'] = initialize_control_points(
            initial_control_points, self.template, initial_cp_spacing,
            deformation_kernel_width, self.dimension, False)

        self.estimate_initial_velocity = estimate_initial_velocity
        self.initial_velocity_weight = initial_velocity_weight

        self.number_of_control_points = len(
            self.fixed_effects['control_points'])
        self.number_of_time_points = number_of_time_points

        # Impulse
        self.fixed_effects['impulse_t'] = initialize_impulse(
            initial_impulse_t, self.number_of_time_points,
            self.number_of_control_points, self.dimension)
        # Initial velocity is a fixed effect only when it is being estimated.
        if (self.estimate_initial_velocity):
            self.fixed_effects[
                'initial_velocity'] = initialize_initial_velocity(
                    initial_velocity, self.number_of_control_points,
                    self.dimension)
Ejemplo n.º 13
0
def instantiate_longitudinal_atlas_model(xml_parameters,
                                         dataset=None,
                                         ignore_noise_variance=False):
    """
    Build a LongitudinalAtlas model from xml parameters.

    Initializes the spatiotemporal reference frame, the fixed effects
    (template, control points, momenta, modulation matrix, reference time,
    variances) and the individual random effect realizations. Unless
    `ignore_noise_variance` is set, missing noise variances are estimated
    from the initial residuals (which requires `dataset`).

    Returns
    -------
    tuple
        (model, individual_RER) where individual_RER maps 'sources',
        'onset_age' and 'log_acceleration' to their initial arrays.
    """
    model = LongitudinalAtlas()

    # Deformation object -----------------------------------------------------------------------------------------------
    model.spatiotemporal_reference_frame.set_kernel(
        kernel_factory.factory(xml_parameters.deformation_kernel_type,
                               xml_parameters.deformation_kernel_width))
    model.spatiotemporal_reference_frame.set_concentration_of_time_points(
        xml_parameters.concentration_of_time_points)
    model.spatiotemporal_reference_frame.set_number_of_time_points(
        xml_parameters.number_of_time_points)
    model.spatiotemporal_reference_frame.set_use_rk2(xml_parameters.use_rk2)

    # Initial fixed effects and associated priors ----------------------------------------------------------------------
    # Template.
    model.is_frozen['template_data'] = xml_parameters.freeze_template
    model.initialize_template_attributes(
        xml_parameters.template_specifications)
    model.use_sobolev_gradient = xml_parameters.use_sobolev_gradient
    model.smoothing_kernel_width = xml_parameters.deformation_kernel_width * xml_parameters.sobolev_kernel_width_ratio
    model.initialize_template_data_variables()

    # Control points.
    model.is_frozen['control_points'] = xml_parameters.freeze_control_points
    if xml_parameters.initial_control_points is not None:
        control_points = read_2D_array(xml_parameters.initial_control_points)
        print('>> Reading ' + str(len(control_points)) +
              ' initial control points from file: ' +
              xml_parameters.initial_control_points)
        model.set_control_points(control_points)
    else:
        model.initial_cp_spacing = xml_parameters.initial_cp_spacing
    model.initialize_control_points_variables()

    # Momenta.
    model.is_frozen['momenta'] = xml_parameters.freeze_momenta
    if xml_parameters.initial_momenta is not None:
        momenta = read_3D_array(xml_parameters.initial_momenta)
        print('>> Reading ' + str(len(momenta)) +
              ' initial momenta from file: ' + xml_parameters.initial_momenta)
        model.set_momenta(momenta)
    model.initialize_momenta_variables()

    # Modulation matrix.
    model.is_frozen[
        'modulation_matrix'] = xml_parameters.freeze_modulation_matrix
    if xml_parameters.initial_modulation_matrix is not None:
        modulation_matrix = read_2D_array(
            xml_parameters.initial_modulation_matrix)
        # A single source comes back as a flat vector; make it a column.
        if len(modulation_matrix.shape) == 1:
            modulation_matrix = modulation_matrix.reshape(-1, 1)
        print('>> Reading ' + str(modulation_matrix.shape[1]) +
              '-source initial modulation matrix from file: ' +
              xml_parameters.initial_modulation_matrix)
        model.set_modulation_matrix(modulation_matrix)
    else:
        model.number_of_sources = xml_parameters.number_of_sources
    model.initialize_modulation_matrix_variables()

    # Reference time.
    model.is_frozen['reference_time'] = xml_parameters.freeze_reference_time
    model.set_reference_time(xml_parameters.t0)
    model.priors['reference_time'].set_variance(
        xml_parameters.initial_time_shift_variance)
    model.initialize_reference_time_variables()

    # Time-shift variance.
    model.is_frozen[
        'time_shift_variance'] = xml_parameters.freeze_time_shift_variance
    model.set_time_shift_variance(xml_parameters.initial_time_shift_variance)

    # Log-acceleration.
    model.is_frozen[
        'log_acceleration_variance'] = xml_parameters.freeze_log_acceleration_variance
    model.individual_random_effects['log_acceleration'].set_mean(
        xml_parameters.initial_log_acceleration_mean)
    model.set_log_acceleration_variance(
        xml_parameters.initial_log_acceleration_variance)

    # Initial random effects realizations ------------------------------------------------------------------------------
    number_of_subjects = len(xml_parameters.dataset_filenames)
    total_number_of_observations = sum(
        [len(elt) for elt in xml_parameters.dataset_filenames])

    # Onset ages.
    if xml_parameters.initial_onset_ages is not None:
        onset_ages = read_2D_array(xml_parameters.initial_onset_ages)
        print('>> Reading initial onset ages from file: ' +
              xml_parameters.initial_onset_ages)
    else:
        onset_ages = np.zeros(
            (number_of_subjects, )) + model.get_reference_time()
        print(
            '>> Initializing all onset ages to the initial reference time: %.2f'
            % model.get_reference_time())

    # Log-accelerations.
    if xml_parameters.initial_log_accelerations is not None:
        log_accelerations = read_2D_array(
            xml_parameters.initial_log_accelerations)
        print('>> Reading initial log-accelerations from file: ' +
              xml_parameters.initial_log_accelerations)
    else:
        log_accelerations = np.zeros((number_of_subjects, ))
        print('>> Initializing all log-accelerations to zero.')

    # Sources.
    if xml_parameters.initial_sources is not None:
        sources = read_2D_array(xml_parameters.initial_sources).reshape(
            (-1, model.number_of_sources))
        print('>> Reading initial sources from file: ' +
              xml_parameters.initial_sources)
    else:
        sources = np.zeros((number_of_subjects, model.number_of_sources))
        print('>> Initializing all sources to zero')

    # Final gathering.
    individual_RER = {}
    individual_RER['sources'] = sources
    individual_RER['onset_age'] = onset_ages
    individual_RER['log_acceleration'] = log_accelerations

    # Special case of the noise variance -------------------------------------------------------------------------------
    model.is_frozen['noise_variance'] = xml_parameters.freeze_noise_variance
    initial_noise_variance = model.get_noise_variance()

    # Compute residuals if needed.
    if not ignore_noise_variance:

        # Compute initial residuals if needed.
        # A negative variance is the sentinel for "estimate from residuals".
        if np.min(initial_noise_variance) < 0:

            template_data, template_points, control_points, momenta, modulation_matrix \
                = model._fixed_effects_to_torch_tensors(False)
            sources, onset_ages, log_accelerations = model._individual_RER_to_torch_tensors(
                individual_RER, False)
            absolute_times, tmin, tmax = model._compute_absolute_times(
                dataset.times, onset_ages, log_accelerations)
            model._update_spatiotemporal_reference_frame(
                template_points, control_points, momenta, modulation_matrix,
                tmin, tmax)
            residuals = model._compute_residuals(dataset, template_data,
                                                 absolute_times, sources)

            # Accumulate residuals per object over all subjects and visits.
            residuals_per_object = np.zeros((model.number_of_objects, ))
            for i in range(len(residuals)):
                for j in range(len(residuals[i])):
                    residuals_per_object += residuals[i][j].data.numpy()

            # Initialize noise variance fixed effect, and the noise variance prior if needed.
            for k, obj in enumerate(
                    xml_parameters.template_specifications.values()):
                dof = total_number_of_observations * obj['noise_variance_prior_normalized_dof'] * \
                      model.objects_noise_dimension[k]
                nv = 0.01 * residuals_per_object[k] / dof

                if initial_noise_variance[k] < 0:
                    print(
                        '>> Initial noise variance set to %.2f based on the initial mean residual value.'
                        % nv)
                    model.objects_noise_variance[k] = nv

        # Initialize the dof if needed.
        if not model.is_frozen['noise_variance']:
            for k, obj in enumerate(
                    xml_parameters.template_specifications.values()):
                dof = total_number_of_observations * obj['noise_variance_prior_normalized_dof'] * \
                      model.objects_noise_dimension[k]
                model.priors['noise_variance'].degrees_of_freedom.append(dof)

    # Final initialization steps by the model object itself ------------------------------------------------------------
    model.update()

    return model, individual_RER
    momenta = torch.from_numpy(momenta).type(model_options['tensor_scalar_type'])

    # Modulation matrix.
    modulation_matrix = read_2D_array(model_options['initial_modulation_matrix'])
    if len(modulation_matrix.shape) == 1:
        modulation_matrix = modulation_matrix.reshape(-1, 1)
    logger.info('>> Reading ' + str(modulation_matrix.shape[1]) + '-source initial modulation matrix from file: '
          + model_options['initial_modulation_matrix'])
    modulation_matrix = torch.from_numpy(modulation_matrix).type(model_options['tensor_scalar_type'])

    """
    Instantiate the spatiotemporal reference frame, update and write.
    """

    spatiotemporal_reference_frame = SpatiotemporalReferenceFrame(
        kernel=kernel_factory.factory(kernel_type=model_options['deformation_kernel_type'],
                                      kernel_width=model_options['deformation_kernel_width']),
        concentration_of_time_points=model_options['concentration_of_time_points'],
        number_of_time_points=model_options['number_of_time_points'],
        use_rk2_for_shoot=model_options['use_rk2_for_shoot'],
        use_rk2_for_flow=model_options['use_rk2_for_flow'],
    )

    spatiotemporal_reference_frame.set_template_points_t0(template_points)
    spatiotemporal_reference_frame.set_control_points_t0(control_points)
    spatiotemporal_reference_frame.set_momenta_t0(momenta)
    spatiotemporal_reference_frame.set_modulation_matrix_t0(modulation_matrix)
    spatiotemporal_reference_frame.set_t0(model_options['t0'])
    spatiotemporal_reference_frame.set_tmin(model_options['tmin'])
    spatiotemporal_reference_frame.set_tmax(model_options['tmax'])
    spatiotemporal_reference_frame.update()
    def __init__(self,
                 kernel_type,
                 kernel_device='CPU',
                 use_cuda=False,
                 data_size='small'):
        """
        Benchmark/test fixture: builds two surface meshes plus the matching
        varifold attachment and kernel objects.

        :param kernel_type: kernel identifier forwarded to kernel_factory.factory.
        :param kernel_device: 'CPU' or 'GPU' (case-insensitive); selects the
            torch tensor type used for the mesh points.
        :param use_cuda: unused here; kept for interface compatibility.
        :param data_size: 'small' or 'large' to load meshes from file, or an
            int-like value giving the number of synthetic triangles to
            generate (up to ~16k).
        """
        np.random.seed(42)  # deterministic synthetic meshes
        kernel_width = 10.

        # Map the requested device onto a torch tensor type.
        device = kernel_device.upper()
        if device == 'CPU':
            tensor_scalar_type = torch.FloatTensor
        elif device == 'GPU':
            tensor_scalar_type = torch.cuda.FloatTensor
        else:
            raise RuntimeError("unknown kernel_device: %s" % kernel_device)

        self.multi_object_attachment = MultiObjectAttachment(['varifold'], [
            kernel_factory.factory(
                kernel_type, kernel_width, device=kernel_device)
        ])

        self.kernel = kernel_factory.factory(kernel_type,
                                             kernel_width,
                                             device=kernel_device)

        reader = DeformableObjectReader()

        if data_size == 'small':
            self.surface_mesh_1 = reader.create_object(
                path_to_small_surface_mesh_1, 'SurfaceMesh',
                tensor_scalar_type)
            self.surface_mesh_2 = reader.create_object(
                path_to_small_surface_mesh_2, 'SurfaceMesh',
                tensor_scalar_type)
        elif data_size == 'large':
            self.surface_mesh_1 = reader.create_object(
                path_to_large_surface_mesh_1, 'SurfaceMesh',
                tensor_scalar_type)
            self.surface_mesh_2 = reader.create_object(
                path_to_large_surface_mesh_2, 'SurfaceMesh',
                tensor_scalar_type)
        else:
            data_size = int(data_size)
            connectivity = np.array(
                list(itertools.combinations(range(100),
                                            3))[:data_size])  # up to ~16k.

            def _make_random_mesh():
                # One synthetic triangulated mesh over random points.
                mesh = SurfaceMesh(3)
                mesh.set_points(np.random.randn(np.max(connectivity) + 1, 3))
                mesh.set_connectivity(connectivity)
                mesh.update()
                return mesh

            # Mesh 1 is generated before mesh 2 to preserve the original
            # np.random call order (fixed by the seed above).
            self.surface_mesh_1 = _make_random_mesh()
            self.surface_mesh_2 = _make_random_mesh()

        # Common to all three branches.
        self.surface_mesh_1_points = tensor_scalar_type(
            self.surface_mesh_1.get_points())
Ejemplo n.º 16
0
    def test_geodesic_shooting(self):
        """
        Test the shooting with a single cp. tests with (tmin=t0=0,tmax=1 ; tmin=-1,tmax=t0=0.; tmin=-1,t0=0,tmax=1)
        """
        control_points = np.array([[0.1, 0.2, 0.2]])
        momenta = np.array([[1., 0.2, 0.]])

        control_points_torch = Variable(
            torch.from_numpy(control_points).type(
                Settings().tensor_scalar_type))
        momenta_torch = Variable(
            torch.from_numpy(momenta).type(Settings().tensor_scalar_type))

        geodesic = Geodesic()
        geodesic.set_kernel(kernel_factory.factory('torch', 0.01))
        geodesic.set_use_rk2(True)
        geodesic.concentration_of_time_points = 10
        geodesic.set_momenta_t0(momenta_torch)
        geodesic.set_control_points_t0(control_points_torch)

        def check_straight_line(tmin, tmax, t0):
            # With a single control point the flow is a straight line:
            # cp(t) = cp0 + t * mom0, and the momenta stay constant.
            geodesic.set_tmin(tmin)
            geodesic.set_tmax(tmax)
            geodesic.set_t0(t0)
            geodesic.update()

            cp_traj = geodesic._get_control_points_trajectory()
            mom_traj = geodesic._get_momenta_trajectory()
            times_traj = geodesic._get_times()

            self.assertTrue(len(cp_traj) == len(mom_traj))
            self.assertTrue(len(times_traj) == len(cp_traj))

            for (cp, mom, time) in zip(cp_traj, mom_traj, times_traj):
                self.assertTrue(
                    np.allclose(cp.detach().numpy(),
                                control_points + time * momenta))
                self.assertTrue(np.allclose(mom.detach().numpy(), momenta))

        # The three configurations announced in the docstring. The previous
        # version ran (tmin=-1, tmax=0, t0=0) three times by copy-paste
        # mistake, leaving the other two configurations untested.
        check_straight_line(0., 1., 0.)
        check_straight_line(-1., 0., 0.)
        check_straight_line(-1., 1., 0.)
Ejemplo n.º 17
0
 def test_cuda_kernel_factory_from_string(self):
     """Check that a kernel built from its string identifier is valid."""
     for k in ['keops']:
         # logging.* takes lazy %-style arguments, not print-style varargs:
         # the original call silently dropped k from the message.
         logging.debug("testing kernel=%s", k)
         instance = kernel_factory.factory(k, kernel_width=1.)
         self.__isKernelValid(instance)
def _exp_parallelize(control_points, initial_momenta, projected_momenta,
                     xml_parameters):
    """
    Parallel-transports *projected_momenta* along the main geodesic defined by
    (*control_points*, *initial_momenta*), then, at each time point of the
    geodesic, shoots with the transported momenta and writes the resulting
    control points, momenta and deformed shapes to disk.

    :param control_points: torch tensor of control points defining the geodesic.
    :param initial_momenta: torch tensor of momenta defining the geodesic.
    :param projected_momenta: torch tensor of momenta to transport, already
        carried at the geodesic control points.
    :param xml_parameters: parsed xml configuration object.
    """
    # Build the template multi-object from the xml specifications.
    # NOTE(review): create_template_metadata is unpacked into 5 values here;
    # confirm this matches the current return arity of that helper.
    objects_list, objects_name, objects_name_extension, _, _ = create_template_metadata(
        xml_parameters.template_specifications)
    template = DeformableMultiObject()
    template.object_list = objects_list
    template.update()

    template_data = template.get_points()
    template_data_torch = Variable(
        torch.from_numpy(template_data).type(Settings().tensor_scalar_type))

    # Main geodesic, along which the transport is performed.
    geodesic = Geodesic()
    geodesic.concentration_of_time_points = xml_parameters.concentration_of_time_points
    geodesic.set_kernel(
        kernel_factory.factory(xml_parameters.deformation_kernel_type,
                               xml_parameters.deformation_kernel_width))
    geodesic.set_use_rk2(xml_parameters.use_rk2)

    # Those are mandatory parameters.
    assert xml_parameters.tmin != -float(
        "inf"), "Please specify a minimum time for the geodesic trajectory"
    assert xml_parameters.tmax != float(
        "inf"), "Please specify a maximum time for the geodesic trajectory"

    geodesic.tmin = xml_parameters.tmin
    geodesic.tmax = xml_parameters.tmax
    if xml_parameters.t0 is None:
        geodesic.t0 = geodesic.tmin
    else:
        geodesic.t0 = xml_parameters.t0

    geodesic.set_momenta_t0(initial_momenta)
    geodesic.set_control_points_t0(control_points)
    geodesic.set_template_data_t0(template_data_torch)
    geodesic.update()

    # We write the flow of the geodesic

    geodesic.write("Regression", objects_name, objects_name_extension,
                   template)

    # Now we transport!
    parallel_transport_trajectory = geodesic.parallel_transport(
        projected_momenta)

    # Getting trajectory caracteristics:
    times = geodesic._get_times()
    control_points_traj = geodesic._get_control_points_trajectory()
    momenta_traj = geodesic._get_momenta_trajectory()
    template_data_traj = geodesic._get_template_data_trajectory()

    # Exponential used to shoot from each geodesic time point with the
    # transported momenta.
    exponential = Exponential()
    exponential.number_of_time_points = xml_parameters.number_of_time_points
    exponential.set_kernel(
        kernel_factory.factory(xml_parameters.deformation_kernel_type,
                               xml_parameters.deformation_kernel_width))
    exponential.set_use_rk2(xml_parameters.use_rk2)

    # We save this trajectory, and the corresponding shape trajectory
    for i, (time, cp, mom, transported_mom, td) in enumerate(
            zip(times, control_points_traj, momenta_traj,
                parallel_transport_trajectory, template_data_traj)):
        # Writing the momenta/cps
        write_2D_array(
            cp.data.numpy(),
            "control_Points_tp_" + str(i) + "__age_" + str(time) + ".txt")
        write_3D_array(mom.data.numpy(),
                       "momenta_tp_" + str(i) + "__age_" + str(time) + ".txt")
        write_3D_array(
            transported_mom.data.numpy(),
            "transported_momenta_tp_" + str(i) + "__age_" + str(time) + ".txt")

        # Shooting from the geodesic:
        exponential.set_initial_template_data(td)
        exponential.set_initial_control_points(cp)
        exponential.set_initial_momenta(transported_mom)
        exponential.update()

        # Uncomment for massive writing, useful for debugging.
        # dir = "exp_"+str(i)+"_"+str(time)
        # if not(os.path.isdir(os.path.join(Settings().output_dir, dir))):
        #     os.mkdir(os.path.join(Settings().output_dir, dir))
        # exponential.write_flow([os.path.join(dir, elt) for elt in objects_name],
        #                        objects_name_extension,
        #                        template)
        # exponential.write_control_points_and_momenta_flow(os.path.join(dir, "cp_and_mom"))

        # Write the end point of the shoot as the parallel curve at time t_i.
        parallel_td = exponential.get_template_data()
        template.set_points(parallel_td)
        names = [
            objects_name[k] + "_parallel_curve_tp_" + str(i) + "__age_" +
            str(time) + "_" + objects_name_extension[k]
            for k in range(len(objects_name))
        ]
        template.write(names)
deformation_kernel_width = 5
number_of_time_points = 16
indexes = [0, 5, 10, 15]

images = []
for d in range(1):
    print('d = %d' % d)
    images_d = []

    latent_position = np.zeros((1, 3))
    latent_position[d] = 1
    momentum = np.dot(latent_position, principal_directions__pga).reshape(control_points__pga.shape)

    b_exponential = Exponential(
        kernel=kernel_factory.factory('keops', deformation_kernel_width),
        number_of_time_points=number_of_time_points,
        initial_control_points=torch.from_numpy(control_points__pga).float().cuda(),
        initial_momenta=torch.from_numpy(- momentum).float().cuda(),
        initial_template_points={'image_points': global_image_points})
    b_exponential.update()

    f_exponential = Exponential(
        kernel=kernel_factory.factory('keops', deformation_kernel_width),
        number_of_time_points=number_of_time_points,
        initial_control_points=torch.from_numpy(control_points__pga).float().cuda(),
        initial_momenta=torch.from_numpy(momentum).float().cuda(),
        initial_template_points={'image_points': global_image_points})
    f_exponential.update()

    images_d.append(
Ejemplo n.º 20
0
 def setUp(self):
     """Prepare a CPU torch kernel and a multi-object attachment fixture."""
     # Run everything on CPU float tensors.
     Settings().tensor_scalar_type = torch.FloatTensor
     kernel_width = 10.
     # TODO: duplicate the tests with the 'keops' kernel as well?
     self.kernel = kernel_factory.factory('torch', kernel_width)
     self.multi_attach = MultiObjectAttachment()
Ejemplo n.º 21
0
 def test_cuda_kernel_factory(self):
     """Check that every CUDA-capable kernel type builds a valid kernel."""
     for k in [kernel_factory.Type.KEOPS, kernel_factory.Type.TORCH_CUDA]:
         # logging.* takes lazy %-style arguments, not print-style varargs:
         # the original call silently dropped k from the message.
         logging.debug("testing kernel=%s", k)
         instance = kernel_factory.factory(k, kernel_width=1.)
         self.__isKernelValid(instance)
Ejemplo n.º 22
0
 def test_unknown_kernel_string(self):
     """An unrecognized kernel identifier must raise a TypeError."""
     self.assertRaises(TypeError, kernel_factory.factory, 'unknown_type')
def compute_parallel_transport(xml_parameters):
    """
    Takes as input an observation, a set of cp and mom which define the main geodesic, and another set of cp and mom describing the registration.
    Exp-parallel and geodesic-parallel are the two possible modes.
    """
    # These three inputs are mandatory ('is not None' is the idiomatic form).
    assert xml_parameters.initial_control_points is not None, "Please provide initial control points"
    assert xml_parameters.initial_momenta is not None, "Please provide initial momenta"
    assert xml_parameters.initial_momenta_to_transport is not None, "Please provide initial momenta to transport"

    control_points = read_2D_array(xml_parameters.initial_control_points)
    initial_momenta = read_3D_array(xml_parameters.initial_momenta)
    initial_momenta_to_transport = read_3D_array(
        xml_parameters.initial_momenta_to_transport)

    kernel = kernel_factory.factory(kernel_factory.Type.TorchKernel,
                                    xml_parameters.deformation_kernel_width)

    # When no dedicated control points are given for the momenta to transport,
    # assume they live on the main geodesic's control points (no projection).
    if xml_parameters.initial_control_points_to_transport is None:
        msg = "initial-control-points-to-transport was not specified, I am assuming they are the same as initial-control-points"
        warnings.warn(msg)
        control_points_to_transport = control_points
        need_to_project_initial_momenta = False
    else:
        control_points_to_transport = read_2D_array(
            xml_parameters.initial_control_points_to_transport)
        need_to_project_initial_momenta = True

    control_points_torch = Variable(
        torch.from_numpy(control_points).type(Settings().tensor_scalar_type))
    initial_momenta_torch = Variable(
        torch.from_numpy(initial_momenta).type(Settings().tensor_scalar_type))
    initial_momenta_to_transport_torch = Variable(
        torch.from_numpy(initial_momenta_to_transport).type(
            Settings().tensor_scalar_type))

    # We start by projecting the initial momenta if they are not carried at the right control points.
    if need_to_project_initial_momenta:
        # Solve K(cp) @ projected = velocity for the projected momenta, via a
        # Cholesky factorization of the kernel Gram matrix.
        control_points_to_transport_torch = Variable(
            torch.from_numpy(control_points_to_transport).type(
                Settings().tensor_scalar_type))
        velocity = kernel.convolve(control_points_torch,
                                   control_points_to_transport_torch,
                                   initial_momenta_to_transport_torch)
        kernel_matrix = kernel.get_kernel_matrix(control_points_torch)
        # NOTE(review): torch.potrf / torch.potrs were deprecated and later
        # removed in favor of torch.cholesky / torch.cholesky_solve — confirm
        # the pinned torch version before upgrading.
        cholesky_kernel_matrix = torch.potrf(kernel_matrix)
        # cholesky_kernel_matrix = Variable(torch.Tensor(np.linalg.cholesky(kernel_matrix.data.numpy())).type_as(kernel_matrix))#Dirty fix if pytorch fails.
        projected_momenta = torch.potrs(
            velocity, cholesky_kernel_matrix).squeeze().contiguous()

    else:
        projected_momenta = initial_momenta_to_transport_torch

    # Dispatch on the requested parallelization mode (default: exp-parallel).
    if xml_parameters.use_exp_parallelization in [None, True]:
        _exp_parallelize(control_points_torch, initial_momenta_torch,
                         projected_momenta, xml_parameters)

    else:
        _geodesic_parallelize(control_points_torch, initial_momenta_torch,
                              projected_momenta, xml_parameters)
Ejemplo n.º 24
0
 def setUp(self):
     """Build a KEOPS kernel instance of unit width for the tests."""
     super().setUp()
     keops_type = kernel_factory.Type.KEOPS
     self.kernel_instance = kernel_factory.factory(keops_type,
                                                   kernel_width=1.)
Ejemplo n.º 25
0
 def test_non_cuda_kernel_factory(self):
     """Check that every non-CUDA kernel type builds a valid kernel."""
     for k in [kernel_factory.Type.NO_KERNEL, kernel_factory.Type.TORCH]:
         # logging.* takes lazy %-style arguments, not print-style varargs:
         # the original call silently dropped k from the message.
         logging.debug("testing kernel=%s", k)
         instance = kernel_factory.factory(k, kernel_width=1.)
         self.__isKernelValid(instance)
Ejemplo n.º 26
0
def instantiate_geodesic_regression_model(xml_parameters,
                                          dataset=None,
                                          ignore_noise_variance=False):
    """
    Builds and initializes a GeodesicRegression model from parsed xml parameters.

    :param xml_parameters: parsed xml configuration object.
    :param dataset: dataset whose first subject's observations are used to
        estimate the noise variance; only required when ignore_noise_variance
        is False and some object has a negative (i.e. unset) noise variance.
    :param ignore_noise_variance: when True, skips the automatic noise
        variance initialization.
    :return: the initialized GeodesicRegression model.
    """
    model = GeodesicRegression()

    # Deformation object -----------------------------------------------------------------------------------------------
    model.geodesic.set_kernel(
        kernel_factory.factory(xml_parameters.deformation_kernel_type,
                               xml_parameters.deformation_kernel_width))
    model.geodesic.concentration_of_time_points = xml_parameters.concentration_of_time_points
    model.geodesic.t0 = xml_parameters.t0
    model.geodesic.set_use_rk2(xml_parameters.use_rk2)

    # Initial fixed effects --------------------------------------------------------------------------------------------
    # Template.
    model.freeze_template = xml_parameters.freeze_template  # this should happen before the init of the template and the cps
    model.initialize_template_attributes(
        xml_parameters.template_specifications)
    model.use_sobolev_gradient = xml_parameters.use_sobolev_gradient
    model.smoothing_kernel_width = xml_parameters.deformation_kernel_width * xml_parameters.sobolev_kernel_width_ratio

    # Control points.
    model.freeze_control_points = xml_parameters.freeze_control_points
    if xml_parameters.initial_control_points is not None:
        control_points = read_2D_array(xml_parameters.initial_control_points)
        print(">> Reading " + str(len(control_points)) +
              " initial control points from file " +
              xml_parameters.initial_control_points)
        model.set_control_points(control_points)
    else:
        # No file given: the model will lay out control points on a regular
        # grid with this spacing.
        model.initial_cp_spacing = xml_parameters.initial_cp_spacing

    # Momenta.
    if xml_parameters.initial_momenta is not None:
        momenta = read_3D_array(xml_parameters.initial_momenta)
        print('>> Reading initial momenta from file: ' +
              xml_parameters.initial_momenta)
        model.set_momenta(momenta)

    # Final initialization steps by the model object itself ------------------------------------------------------------
    model.update()

    # Special case of the noise variance hyperparameter ----------------------------------------------------------------
    # Compute residuals if needed. A negative noise variance means "estimate
    # it from the data": shoot the geodesic over the observation time span of
    # the first subject and accumulate per-object attachment residuals.
    if not ignore_noise_variance and np.min(model.objects_noise_variance) < 0:

        template_data_torch, template_points_torch, control_points_torch, momenta_torch \
            = model._fixed_effects_to_torch_tensors(False)
        target_times = dataset.times[0]
        target_objects = dataset.deformable_objects[0]

        model.geodesic.set_tmin(min(target_times))
        model.geodesic.set_tmax(max(target_times))
        model.geodesic.set_template_points_t0(template_points_torch)
        model.geodesic.set_control_points_t0(control_points_torch)
        model.geodesic.set_momenta_t0(momenta_torch)
        model.geodesic.update()

        residuals = np.zeros((model.number_of_objects, ))
        for (time, target) in zip(target_times, target_objects):
            deformed_points = model.geodesic.get_template_points(time)
            deformed_data = model.template.get_deformed_data(
                deformed_points, template_data_torch)
            residuals += model.multi_object_attachment.compute_distances(
                deformed_data, model.template, target).data.numpy()

        # Initialize the noise variance hyperparameter.
        # Heuristic: 1% of the mean residual per observation time.
        for k, obj in enumerate(xml_parameters.template_specifications.keys()):
            if model.objects_noise_variance[k] < 0:
                nv = 0.01 * residuals[k] / float(len(target_times))
                model.objects_noise_variance[k] = nv
                print('>> Automatically chosen noise std: %.4f [ %s ]' %
                      (math.sqrt(nv), obj))

    # Return the initialized model.
    return model
def instantiate_deterministic_atlas_model(xml_parameters,
                                          dataset=None,
                                          ignore_noise_variance=False):
    """
    Builds and initializes a DeterministicAtlas model from parsed xml parameters.

    :param xml_parameters: parsed xml configuration object.
    :param dataset: dataset whose first observation per subject is used to
        estimate the noise variance; only required when ignore_noise_variance
        is False and some object has a negative (i.e. unset) noise variance.
    :param ignore_noise_variance: when True, skips the automatic noise
        variance initialization.
    :return: the initialized DeterministicAtlas model.
    """
    model = DeterministicAtlas()

    # Deformation object -----------------------------------------------------------------------------------------------
    model.exponential.kernel = kernel_factory.factory(
        xml_parameters.deformation_kernel_type,
        xml_parameters.deformation_kernel_width)
    model.exponential.number_of_time_points = xml_parameters.number_of_time_points
    model.exponential.set_use_rk2(xml_parameters.use_rk2)

    # Initial fixed effects --------------------------------------------------------------------------------------------
    # Template.
    model.freeze_template = xml_parameters.freeze_template  # this should happen before the init of the template and the cps
    model.initialize_template_attributes(
        xml_parameters.template_specifications)
    model.use_sobolev_gradient = xml_parameters.use_sobolev_gradient
    model.smoothing_kernel_width = xml_parameters.deformation_kernel_width * xml_parameters.sobolev_kernel_width_ratio

    # Control points.
    model.freeze_control_points = xml_parameters.freeze_control_points
    if xml_parameters.initial_control_points is not None:
        control_points = read_2D_array(xml_parameters.initial_control_points)
        print(">> Reading " + str(len(control_points)) +
              " initial control points from file " +
              xml_parameters.initial_control_points)
        model.set_control_points(control_points)
    else:
        # No file given: the model will lay out control points on a regular
        # grid with this spacing.
        model.initial_cp_spacing = xml_parameters.initial_cp_spacing

    # Momenta. The number of subjects is deduced from the momenta file when
    # available, otherwise from the dataset filenames.
    if xml_parameters.initial_momenta is not None:
        momenta = read_3D_array(xml_parameters.initial_momenta)
        print('>> Reading %d initial momenta from file: %s' %
              (momenta.shape[0], xml_parameters.initial_momenta))
        model.set_momenta(momenta)
        model.number_of_subjects = momenta.shape[0]
    else:
        model.number_of_subjects = len(xml_parameters.dataset_filenames)

    # Final initialization steps by the model object itself ------------------------------------------------------------
    model.update()

    # Special case of the noise variance hyperparameter ----------------------------------------------------------------
    # Compute residuals if needed. A negative noise variance means "estimate
    # it from the data": deform the template towards each subject's first
    # observation and accumulate per-object attachment residuals.
    if not ignore_noise_variance and np.min(model.objects_noise_variance) < 0:

        template_data_torch, template_points_torch, control_points_torch, momenta_torch \
            = model._fixed_effects_to_torch_tensors(False)
        targets = dataset.deformable_objects
        targets = [target[0] for target in targets]

        residuals_torch = []
        model.exponential.set_initial_template_points(template_points_torch)
        model.exponential.set_initial_control_points(control_points_torch)
        for i, target in enumerate(targets):
            model.exponential.set_initial_momenta(momenta_torch[i])
            model.exponential.update()
            deformed_points = model.exponential.get_template_points()
            deformed_data = model.template.get_deformed_data(
                deformed_points, template_data_torch)
            residuals_torch.append(
                model.multi_object_attachment.compute_distances(
                    deformed_data, model.template, target))

        residuals = np.zeros((model.number_of_objects, ))
        for i in range(len(residuals_torch)):
            residuals += residuals_torch[i].data.numpy()

        # Initialize the noise variance hyperparameter.
        # Heuristic: 1% of the mean residual per subject.
        for k, obj in enumerate(xml_parameters.template_specifications.keys()):
            if model.objects_noise_variance[k] < 0:
                nv = 0.01 * residuals[k] / float(model.number_of_subjects)
                model.objects_noise_variance[k] = nv
                print('>> Automatically chosen noise std: %.4f [ %s ]' %
                      (math.sqrt(nv), obj))

    # Return the initialized model.
    return model