def get_input(self, moving_momentum_path_list, moving_name,
                  init_weight_path_list):
        """Load the moving image and its optional label for one input line.

        :param moving_momentum_path_list: sequence whose item 0 is the moving
            image path and item 1 is the moving label path (None if absent).
        :param moving_name: output name; when None it is derived from the
            moving image path with ``get_file_name``.
        :param init_weight_path_list: accepted for interface compatibility,
            not used here.
        :return: ``(moving, l_moving, moving_name)``. ``moving`` and
            ``l_moving`` are CUDA tensors with two leading singleton axes
            prepended; ``l_moving`` is None when no label path is given.
        """

        def _load_cuda(path):
            # read the voxel array with SimpleITK and move it to the GPU
            return torch.Tensor(
                sitk.GetArrayFromImage(sitk.ReadImage(path))).cuda()

        image_path = moving_momentum_path_list[0]
        label_path = moving_momentum_path_list[1]

        moving = _load_cuda(image_path)[None][None]
        l_moving = (_load_cuda(label_path)[None][None]
                    if label_path is not None else None)
        if moving_name is None:
            moving_name = get_file_name(image_path)

        if self.resize_output == [-1., -1, -1]:
            # this sentinel value means no resizing is performed
            return moving, l_moving, moving_name

        target_size = [1, 1] + self.resize_output
        # intensity image: first-order interpolation
        moving, _ = resample_image(moving, [1, 1, 1],
                                   desiredSize=target_size,
                                   spline_order=1,
                                   zero_boundary=True)
        if l_moving is not None:
            # label map: order-0 interpolation so label values are not blended
            l_moving, _ = resample_image(l_moving, [1, 1, 1],
                                         desiredSize=target_size,
                                         spline_order=0,
                                         zero_boundary=True)
        return moving, l_moving, moving_name
Beispiel #2
0
def compute_warped_image_label(img_label_txt_pth,phi_pth,phi_type, saving_pth):
    """Warp every listed image/label pair with each of its saved deformation maps.

    :param img_label_txt_pth: txt file parsed by ``read_txt_into_list``; every
        entry holds the image path at index 0 and the label path at index 1.
    :param phi_pth: folder containing the deformation maps.
    :param phi_type: glob pattern (joined to ``phi_pth``) selecting map files.
    :param saving_pth: output folder for the warped images and labels.

    A map belongs to an image when the map's file name starts with the image's
    file name. Images with no matching map are skipped (previously
    ``np.stack`` raised on the empty list and aborted the whole run).
    """
    img_label_pth_list = read_txt_into_list(img_label_txt_pth)
    phi_pth_list = glob(os.path.join(phi_pth, phi_type))

    def read_array(pth):
        return sitk.GetArrayFromImage(sitk.ReadImage(pth))

    for pair in img_label_pth_list:
        img_pth, label_pth = pair[0], pair[1]
        fname = get_file_name(img_pth)
        # NOTE(review): prefix match, so "img1" also picks up maps of "img10";
        # kept as-is for compatibility with existing naming.
        phi_sub_list = [pth for pth in phi_pth_list
                        if get_file_name(pth).startswith(fname)]
        num_aug = len(phi_sub_list)
        if num_aug == 0:
            print("no deformation map found for {}, skipped".format(fname))
            continue

        # load one pair at a time instead of keeping every volume in memory
        img = torch.Tensor(read_array(img_pth)[None][None])
        label = torch.Tensor(read_array(label_pth)[None][None])
        img = img.repeat(num_aug, 1, 1, 1, 1)
        label = label.repeat(num_aug, 1, 1, 1, 1)

        phi = np.stack([read_array(pth) for pth in phi_sub_list], 0)
        # reverse the SimpleITK axis order, presumably (N, z, y, x, 3) -> (N, 3, x, y, z)
        phi = torch.Tensor(np.transpose(phi, (0, 4, 3, 2, 1)))

        spacing = 1. / (np.array(img.shape[2:]) - 1)
        # bring the maps to the image resolution before warping
        phi, _ = resample_image(phi, spacing, [1, 3] + list(img.shape[2:]))
        warped_img = compute_warped_image_multiNC(img, phi, spacing, spline_order=1, zero_boundary=True)
        # order-0 interpolation keeps label values discrete
        warped_label = compute_warped_image_multiNC(label, phi, spacing, spline_order=0, zero_boundary=True)

        out_stems = [get_file_name(pth).replace("_phi", "") for pth in phi_sub_list]
        save_image_with_given_reference(warped_img, [img_pth] * num_aug, saving_pth,
                                        [stem + '_warped' for stem in out_stems])
        save_image_with_given_reference(warped_label, [img_pth] * num_aug, saving_pth,
                                        [stem + '_label' for stem in out_stems])
    def get_input(self, moving_momentum_path_list, fname_list,
                  init_weight_path_list):
        """Load the moving image, optional label and the momenta of one input line.

        Layout of ``moving_momentum_path_list``: moving image path, moving label
        path (None if absent), then one momentum path per target.

        :param fname_list: optional ``[moving_name, target_name_1, ...]``; when
            None the names are derived from the file names, with
            ``"_0000_Momentum"`` stripped from the target names.
        :param init_weight_path_list: optional list of weight-map paths; when
            given they are loaded and wrapped in one extra list.
        :return: ``(moving, l_moving, momentum_list, init_weight_list,
            moving_name, target_name_list)``
        """

        def load(path):
            return torch.Tensor(
                sitk.GetArrayFromImage(sitk.ReadImage(path))).cuda()

        moving_path = moving_momentum_path_list[0]
        label_path = moving_momentum_path_list[1]
        momentum_paths = moving_momentum_path_list[2:]

        moving = load(moving_path)[None][None]
        l_moving = None if label_path is None else load(label_path)[None][None]
        # reverse the axis order of the loaded array and add a batch axis
        momentum_list = [load(p).permute(3, 2, 1, 0)[None] for p in momentum_paths]

        init_weight_list = None
        if init_weight_path_list is not None:
            init_weight_list = [[load(p) for p in init_weight_path_list]]

        if fname_list is not None:
            moving_name, target_name_list = fname_list[0], fname_list[1:]
        else:
            moving_name = get_file_name(moving_path)
            target_name_list = [get_file_name(p).replace("_0000_Momentum", '')
                                for p in momentum_paths]

        if self.resize_output != [-1., -1, -1]:
            new_size = [1, 1] + self.resize_output
            moving, _ = resample_image(moving, [1, 1, 1],
                                       desiredSize=new_size,
                                       spline_order=1,
                                       zero_boundary=True)
            if l_moving is not None:
                # order-0 interpolation for the label map
                l_moving, _ = resample_image(l_moving, [1, 1, 1],
                                             desiredSize=new_size,
                                             spline_order=0,
                                             zero_boundary=True)
        return moving, l_moving, momentum_list, init_weight_list, moving_name, target_name_list
    def get_input(self, moving_momentum_path_list, fname_list,
                  init_weight_path_list):
        """Load moving image, optional label, momenta and affine maps of one input line.

        Layout of ``moving_momentum_path_list``: moving image path, moving
        label path (None if absent), mom_1 ... mom_m, affine_1 ... affine_m.

        :param fname_list: optional ``[moving_name, target_name_1, ...]``; when
            None the names are derived from the momentum file names.
        :param init_weight_path_list: optional list of weight-map paths.
        :return: ``(moving, l_moving, momentum_list, init_weight_list,
            affine_list, inverse_affine_list, moving_name, target_name_list)``
        """

        def load(path):
            return torch.Tensor(
                sitk.GetArrayFromImage(sitk.ReadImage(path))).cuda()

        moving_path = moving_momentum_path_list[0]
        moving = load(moving_path)[None][None]
        label_path = moving_momentum_path_list[1]
        l_moving = None if label_path is None else load(label_path)[None][None]

        # the remaining entries split evenly into momentum paths and affine paths
        num_m = (len(moving_momentum_path_list) - 2) // 2
        momentum_paths = moving_momentum_path_list[2:num_m + 2]
        affine_paths = moving_momentum_path_list[num_m + 2:]

        momentum_list = [load(p).permute(3, 2, 1, 0)[None] for p in momentum_paths]

        # NOTE(review): the maps are built at the pre-resize moving size; confirm
        # this matches the resized image when resize_output is set.
        affine_list, inverse_affine_list = [], []
        for p in affine_paths:
            forward_inverse = self.read_affine_param_and_output_map(p, moving.shape[2:])
            affine_list.append(forward_inverse[0])
            inverse_affine_list.append(forward_inverse[1])

        init_weight_list = None
        if init_weight_path_list is not None:
            init_weight_list = [[load(p) for p in init_weight_path_list]]

        if fname_list is not None:
            moving_name, target_name_list = fname_list[0], fname_list[1:]
        else:
            moving_name = get_file_name(moving_path)
            target_name_list = [get_file_name(p).replace("_0000_Momentum", '')
                                for p in momentum_paths]

        if self.resize_output != [-1., -1, -1]:
            new_size = [1, 1] + self.resize_output
            moving, _ = resample_image(moving, [1, 1, 1],
                                       desiredSize=new_size,
                                       spline_order=1,
                                       zero_boundary=True)
            if l_moving is not None:
                # order-0 interpolation for the label map
                l_moving, _ = resample_image(l_moving, [1, 1, 1],
                                             desiredSize=new_size,
                                             spline_order=0,
                                             zero_boundary=True)

        return moving, l_moving, momentum_list, init_weight_list, affine_list, inverse_affine_list, moving_name, target_name_list
    def generate_single_res(self, moving, l_moving, momentum, init_weight,
                            initial_map, initial_inverse_map, fname, t_aug,
                            output_path, moving_path):
        """Run the model from one momentum and save the warped image/label and maps.

        :param moving: moving image tensor (batch and channel axes first).
        :param l_moving: moving label tensor, or None.
        :param momentum: momentum tensor, or None to draw a random one uniformly
            in ``[-self.magnitude, self.magnitude]`` on a shrunken grid and
            upsample it.
        :param init_weight: passed to the model as ``local_weights``.
        :param initial_map: optional initial map passed to the model.
        :param initial_inverse_map: optional inverse of ``initial_map``.
        :param fname: stem of the output file names.
        :param t_aug: value written to ``forward_model['tTo']`` of the settings.
        :param output_path: output folder.
        :param moving_path: reference image path used when saving.

        The inverse map is skipped only when ``phi_new`` was actually composed
        with ``initial_inverse_map``. Previously the affine flag alone caused
        the skip, even when no inverse map was given.
        """
        params = self.mermaid_setting
        # NOTE: this writes into self.mermaid_setting itself (no copy is made)
        params['model']['registration_model']['forward_model']['tTo'] = t_aug

        # here we assume the momentum is computed at low_resol_factor=0.5,
        # i.e. the model input grid is twice the momentum size along each axis
        if momentum is not None:
            input_img_sz = [1, 1] + [int(sz * 2) for sz in momentum.shape[2:]]
        else:
            input_img_sz = list(moving.shape)
            momentum_sz_low = [1, 3] + [
                int(dim / self.rand_momentum_shrink_factor)
                for dim in input_img_sz[2:]
            ]
            momentum_sz = [1, 3] + [int(dim / 2) for dim in input_img_sz[2:]]
            momentum = (np.random.rand(*momentum_sz_low) * 2 -
                        1) * self.magnitude
            mom_spacing = 1. / (np.array(momentum_sz_low[2:]) - 1)
            momentum = torch.Tensor(momentum).cuda()
            momentum, _ = resample_image(momentum,
                                         mom_spacing,
                                         momentum_sz,
                                         spline_order=1,
                                         zero_boundary=True)
        if self.resize_output != [-1, -1, -1]:
            # resample the momentum to half of the requested output size
            momentum_sz = [1, 3] + [int(dim / 2) for dim in self.resize_output]
            mom_spacing = 1. / (np.array(momentum_sz[2:]) - 1)
            momentum, _ = resample_image(momentum,
                                         mom_spacing,
                                         momentum_sz,
                                         spline_order=1,
                                         zero_boundary=True)
            input_img_sz = [1, 1] + [int(sz * 2) for sz in momentum.shape[2:]]

        org_spacing = 1.0 / (np.array(moving.shape[2:]) - 1)
        input_spacing = 1.0 / (np.array(input_img_sz[2:]) - 1)
        # resample the moving image when the model grid differs from it
        size_diff = not input_img_sz == list(moving.shape)
        if size_diff:
            input_img, _ = resample_image(moving, org_spacing, input_img_sz)
        else:
            input_img = moving

        # initial maps resampled to the momentum grid for the model
        low_initial_map = None
        low_init_inverse_map = None
        if initial_map is not None:
            low_initial_map, _ = resample_image(
                initial_map, input_spacing, [1, 3] + list(momentum.shape[2:]))
        if initial_inverse_map is not None:
            low_init_inverse_map, _ = resample_image(initial_inverse_map,
                                                     input_spacing, [1, 3] +
                                                     list(momentum.shape[2:]))

        individual_parameters = dict(m=momentum, local_weights=init_weight)
        sz = np.array(input_img.shape)
        res = evaluate_model(input_img,
                             input_img,
                             sz,
                             input_spacing,
                             use_map=True,
                             compute_inverse_map=self.compute_inverse,
                             map_low_res_factor=0.5,
                             compute_similarity_measure_at_low_res=False,
                             spline_order=1,
                             individual_parameters=individual_parameters,
                             shared_parameters=None,
                             params=params,
                             extra_info=None,
                             visualize=False,
                             visual_param=None,
                             given_weight=False,
                             init_map=initial_map,
                             lowres_init_map=low_initial_map,
                             init_inverse_map=initial_inverse_map,
                             lowres_init_inverse_map=low_init_inverse_map)
        phi = res[1]
        phi_new = phi
        if size_diff:
            # bring the map back to the moving image grid
            phi_new, _ = resample_image(phi, input_spacing,
                                        [1, 3] + list(moving.shape[2:]))
        warped = compute_warped_image_multiNC(moving,
                                              phi_new,
                                              org_spacing,
                                              spline_order=1,
                                              zero_boundary=True)
        # True only when phi_new / warped get composed with the inverse initial map
        affine_back = (initial_inverse_map is not None
                       and self.affine_back_to_original_postion)
        if affine_back:
            # here we take zero boundary boundary which need two step image interpolation
            warped = compute_warped_image_multiNC(warped,
                                                  initial_inverse_map,
                                                  org_spacing,
                                                  spline_order=1,
                                                  zero_boundary=True)
            phi_new = compute_warped_image_multiNC(phi_new,
                                                   initial_inverse_map,
                                                   org_spacing,
                                                   spline_order=1)
        save_image_with_given_reference(warped, [moving_path], output_path,
                                        [fname + '_image'])
        if l_moving is not None:
            # we assume the label doesnt lie at the boundary
            l_warped = compute_warped_image_multiNC(l_moving,
                                                    phi_new,
                                                    org_spacing,
                                                    spline_order=0,
                                                    zero_boundary=True)
            save_image_with_given_reference(l_warped, [moving_path],
                                            output_path, [fname + '_label'])

        if self.save_tf_map:
            save_deformation(phi_new, output_path, [fname + '_phi_map'])
            if self.compute_inverse:
                if affine_back:
                    # res[2] does not include the composition with the inverse
                    # initial map applied to phi_new above, so it is not saved
                    print(
                        "Cannot compute the inverse map when affine back to the source image position"
                    )
                    return
                phi_inv = res[2]
                inv_phi_new = phi_inv
                if size_diff:
                    inv_phi_new, _ = resample_image(phi_inv, input_spacing,
                                                    [1, 3] +
                                                    list(moving.shape[2:]))
                save_deformation(inv_phi_new, output_path,
                                 [fname + '_inv_map'])