Esempio n. 1
0
def compute_warped_image_label(img_label_txt_pth,phi_pth,phi_type, saving_pth):
    """Warp every image/label pair with all of its matching deformation maps and save the results.

    A map belongs to an image when the map's file name starts with the image's file name.
    For each matching map, ``<map name without "_phi">_warped`` and ``..._label`` are written
    to ``saving_pth`` using the source image as the reference.

    :param img_label_txt_pth: txt file read by ``read_txt_into_list``; entry ``[0]`` of each
        item is used as the image path and entry ``[1]`` as the label path
    :param phi_pth: folder containing the deformation maps
    :param phi_type: glob pattern selecting the maps inside ``phi_pth``
    :param saving_pth: output folder
    """
    img_label_pth_list = read_txt_into_list(img_label_txt_pth)
    phi_pth_list = glob(os.path.join(phi_pth, phi_type))

    def read_array(pth):
        return sitk.GetArrayFromImage(sitk.ReadImage(pth))

    for img_label_pth in img_label_pth_list:
        img_pth, label_pth = img_label_pth[0], img_label_pth[1]
        fname = get_file_name(img_pth)
        phi_sub_list = [pth for pth in phi_pth_list if get_file_name(pth).startswith(fname)]
        num_aug = len(phi_sub_list)
        if num_aug == 0:
            # np.stack below cannot handle an empty list; nothing to warp for this image
            print("no transformation map found for {}, skipped".format(fname))
            continue

        # load one pair at a time instead of keeping every volume in memory
        img = torch.Tensor(read_array(img_pth)[None][None]).repeat(num_aug, 1, 1, 1, 1)
        label = torch.Tensor(read_array(label_pth)[None][None]).repeat(num_aug, 1, 1, 1, 1)

        phi = np.stack([read_array(pth) for pth in phi_sub_list], 0)
        # reorder axes (0,4,3,2,1): presumably sitk vector arrays (B,Z,Y,X,C) -> (B,C,X,Y,Z)
        phi = torch.Tensor(np.transpose(phi, (0, 4, 3, 2, 1)))

        spacing = 1. / (np.array(img.shape[2:]) - 1)
        # resample_image expects the spacing of its *input* (phi), matching the other callers
        phi_spacing = 1. / (np.array(phi.shape[2:]) - 1)
        phi, _ = resample_image(phi, phi_spacing, [1, 3] + list(img.shape[2:]))

        warped_img = compute_warped_image_multiNC(img, phi, spacing, spline_order=1, zero_boundary=True)
        warped_label = compute_warped_image_multiNC(label, phi, spacing, spline_order=0, zero_boundary=True)

        ref_list = [img_pth] * num_aug
        out_names = [get_file_name(pth).replace("_phi", "") for pth in phi_sub_list]
        save_image_with_given_reference(warped_img, ref_list, saving_pth,
                                        [name + '_warped' for name in out_names])
        save_image_with_given_reference(warped_label, ref_list, saving_pth,
                                        [name + '_label' for name in out_names])
Esempio n. 2
0
    def get_warped_label_map(self, label_map, phi, sched='nn', use_01=False):
        """
        get warped label map

        Warps ``label_map`` by ``phi`` with nearest-neighbour interpolation
        (spline order 0, zero boundary) so label values are not blended.

        :param label_map: label map to warp
        :param phi: transformation map
        :param sched: interpolation scheme; only 'nn' (nearest neighbor) is supported
        :param use_01: indicate the input phi is in [0,1] coord; else the phi is assumed to be [-1,1]
        :return: the warped label map
        :raises ValueError: if ``sched`` is not 'nn'
        """
        if sched != 'nn':
            raise ValueError(" the label warpping method is not implemented")

        # TODO: a dedicated CUDA nn interpolation (get_nn_interpolation) used to be tried
        # first; it is disabled for torch1 compatibility until the new cuda interface is fixed.
        warped_label_map = compute_warped_image_multiNC(
            label_map,
            phi,
            self.spacing,
            spline_order=0,
            zero_boundary=True,
            use_01_input=use_01)

        # Sanity check: nn output should be integer-valued. Sum the *absolute* deviations
        # so positive and negative errors cannot cancel each other out.
        detached = warped_label_map.detach()
        assert torch.sum(torch.abs(detached - detached.round())) < 0.1, \
            "nn interpolation is not precise"
        return warped_label_map
Esempio n. 3
0
def analyze_on_single_res(pair,pair_name, expr_folder=None, color_image=False,model_name='rdmm'):
    """Re-evaluate a stored registration result for one pair and score it.

    Reads the mermaid settings from ``<expr_folder>/mermaid_setting.json``. It creates
    ``<expr_folder>/analysis/<pair_name>/res_analysis``, re-runs ``evaluate_model`` with the
    stored momentum/weights, and warps the moving label with the resulting map.

    :return: tuple ``(dice, avg_jacobi)``; ``dice`` is
        ``get_multi_metric(...)['label_batch_avg_res']['dice']`` (background removed) and
        ``avg_jacobi`` is whatever ``compute_jacobi`` returns for the map
    """
    moving, target, spacing, moving_init_weight, _, m = get_analysis_input(
        pair, expr_folder, pair_name, color_image=color_image, model_name=model_name)
    lmoving, ltarget = get_labeled_image(pair)

    params = pars.ParameterDict()
    params.load_JSON(os.path.join(expr_folder, 'mermaid_setting.json'))

    img_sz = np.array(moving.shape)
    saving_folder = os.path.join(expr_folder, 'analysis', pair_name, 'res_analysis')
    os.makedirs(saving_folder, exist_ok=True)

    info = {'fname': [pair_name], 'saving_folder': [saving_folder]}
    vis_param = setting_visual_saving(expr_folder, pair_name, folder_name='color')

    model_out = evaluate_model(moving, target, img_sz, spacing,
                               use_map=True,
                               compute_inverse_map=False,
                               map_low_res_factor=0.5,
                               compute_similarity_measure_at_low_res=False,
                               spline_order=1,
                               individual_parameters=dict(m=m, local_weights=moving_init_weight),
                               shared_parameters=None,
                               params=params,
                               extra_info=info,
                               visualize=False,
                               visual_param=vis_param,
                               given_weight=True)
    transform = model_out[1]

    # spline order 0: nearest-neighbour warping keeps label values intact
    warped_label = utils.compute_warped_image_multiNC(lmoving, transform, spacing, 0, zero_boundary=True)
    metrics = get_multi_metric(warped_label, ltarget, rm_bg=True)
    return metrics['label_batch_avg_res']['dice'], compute_jacobi(transform, spacing)
Esempio n. 4
0
def read_image_and_map_and_apply_map(image_filename,map_filename):
    """
    Reads an image and a map and applies the map to the image

    :param image_filename: input image filename
    :param map_filename: input map filename
    :return: the warped image (numpy) and its image header as a tuple (im,hdr);
        (None,None) if either the image or the map could not be read
    """
    phi_map, _ = fileio.MapIO().read(map_filename)
    im, hdr, _, _ = fileio.ImageIO().read_to_map_compatible_format(image_filename, phi_map)

    if im is None or phi_map is None:
        print('Could not read map or image')
        return None, None

    # Read the spacing only after the success check: the header may not be usable
    # when the read failed (assumption; previously this lookup happened before the check).
    spacing = hdr['spacing']
    # TODO: check that the spacing is compatible with the map

    # make pytorch arrays for subsequent processing
    im_t = AdaptVal(torch.from_numpy(im))
    map_t = AdaptVal(torch.from_numpy(phi_map))
    im_warped = utils.t2np(utils.compute_warped_image_multiNC(im_t, map_t, spacing))
    return im_warped, hdr
Esempio n. 5
0
def compute_atlas_label(lsource_folder_path, to_atlas_folder_pth, atlas_type,
                        atlas_to_l_switcher, output_folder, max_num=100):
    """Fuse source label maps into atlas space by majority voting and save the result.

    Each label map is one-hot encoded and warped into atlas space with its to-atlas map
    (linear interpolation). The per-label scores are summed over all subjects, and the
    label with the highest score is saved as ``atlas_label`` in ``output_folder``.

    :param lsource_folder_path: folder holding the source label maps (``.nii.gz``)
    :param to_atlas_folder_pth: folder holding the to-atlas transformation maps
    :param atlas_type: file-name suffix of the maps (glob ``"*" + atlas_type``)
    :param atlas_to_l_switcher: ``(old, new)`` pair passed to ``str.replace`` to turn a
        map name into the corresponding label file name
    :param output_folder: output folder
    :param max_num: maximum number of maps used (formerly hard-coded to 100);
        ``None`` uses all of them
    :raises ValueError: if no map matches
    """
    # sort so the subset selected by max_num does not depend on filesystem order
    to_atlas_pth_list = sorted(
        glob(os.path.join(to_atlas_folder_pth, "*" + atlas_type)))[:max_num]
    if not to_atlas_pth_list:
        raise ValueError("no transformation map matching '*{}' found in {}".format(
            atlas_type, to_atlas_folder_pth))

    to_atlas_name_list = [get_file_name(pth) for pth in to_atlas_pth_list]
    l_pth_list = [
        os.path.join(lsource_folder_path,
                     name.replace(*atlas_to_l_switcher) + '.nii.gz')
        for name in to_atlas_name_list
    ]

    def read_array(pth):
        return sitk.GetArrayFromImage(sitk.ReadImage(pth))

    labels = np.stack([read_array(pth)[None] for pth in l_pth_list]).astype(np.int64)
    # The voted channel index is saved as the label value (see torch.max below), so the
    # channels must cover every label value up to the largest one. Counting distinct
    # labels breaks for non-contiguous label sets such as {0, 1, 5}.
    num_c = int(labels.max()) + 1
    to_atlas = torch.Tensor(
        np.stack([np.transpose(read_array(pth)) for pth in to_atlas_pth_list]))
    labels = torch.LongTensor(labels)

    l_onehot = make_one_hot(labels, C=num_c).to(torch.float32)
    spacing = 1. / (np.array(labels.shape[2:]) - 1)
    warped_one_hot = compute_warped_image_multiNC(l_onehot,
                                                  to_atlas,
                                                  spacing=spacing,
                                                  spline_order=1,
                                                  zero_boundary=True)
    # sum over subjects: each channel accumulates the (soft) votes for that label
    sum_one_hot = torch.sum(warped_one_hot, 0, keepdim=True)
    voting = torch.max(sum_one_hot, 1)[1][None].to(torch.float32)
    save_image_with_given_reference(voting, [l_pth_list[0]], output_folder,
                                    ["atlas_label"])
Esempio n. 6
0
    def get_circular_map(self):
        """Warp ``self.map`` by ``self.inverse_map`` and divide by the moving spacing.

        Both maps are moved to the module-level ``device``. The composition uses spline
        order 1 without zero boundary, and the result comes back as a squeezed numpy array.
        Dividing by ``moving_spacing_normalized`` (reshaped to broadcast over the spatial
        axes) presumably expresses the result in voxel units; confirm with callers.
        """
        norm_spacing = self.moving_spacing_normalized
        composed = pyreg_utils.compute_warped_image_multiNC(
            self.map.to(device),
            self.inverse_map.to(device),
            norm_spacing,
            1,
            False)
        composed_np = composed.cpu().numpy().squeeze()
        # shape (-1, 1, ..., 1): one entry per axis, broadcast over the grid
        spacing_bc = norm_spacing.reshape((-1, ) + (1, ) * len(norm_spacing))
        return composed_np / spacing_bc
Esempio n. 7
0
def resample_warped_phi_and_image(source_path, target_path, l_source_path,
                                  l_target_path, phi, spacing):
    """Resample ``phi`` to the target resolution and warp the source image and/or label.

    When an image pair is given, ``phi`` is resampled to the target image size, and the
    source is warped with spline order 1. When a label pair is given, the label is warped
    with spline order 0. It reuses the already-resampled ``phi``, or resamples ``phi`` to
    the target label size if no image pair was given.

    :return: ``(new_phi, warped, l_warped, new_spacing)``; entries stay None for inputs
        that were not provided
    """
    def load_float(pth):
        return sitk.GetArrayFromImage(sitk.ReadImage(pth)).astype(np.float32)

    new_phi = new_spacing = warped = l_warped = None

    if source_path is not None:
        src_np = load_float(source_path)
        tgt_np = load_float(target_path)
        src_t = torch.from_numpy(src_np[None][None]).to(phi.device)
        new_phi, new_spacing = resample_image(phi, spacing, [1, 1] + list(tgt_np.shape),
                                              1, zero_boundary=True)
        warped = py_utils.compute_warped_image_multiNC(src_t, new_phi, new_spacing, 1,
                                                       zero_boundary=True)

    if l_source_path is not None:
        lsrc_np = load_float(l_source_path)
        ltgt_np = load_float(l_target_path)
        lsrc_t = torch.from_numpy(lsrc_np[None][None]).to(phi.device)
        if new_phi is None:
            new_phi, new_spacing = resample_image(phi, spacing, [1, 1] + list(ltgt_np.shape),
                                                  1, zero_boundary=True)
        l_warped = py_utils.compute_warped_image_multiNC(lsrc_t, new_phi, new_spacing, 0,
                                                         zero_boundary=True)

    return new_phi, warped, l_warped, new_spacing
Esempio n. 8
0
def eval_on_dirlab_per_case(forward_map,inv_map, pair_name,moving, target,record_path):
    """Mean landmark error of a DIR-Lab pair in both directions.

    Landmarks are converted to normalized image coordinates (index * 1/(n-1)). Target
    landmarks are pushed through ``forward_map`` and source landmarks through ``inv_map``.
    The four point sets are saved as vtk files under ``<record_path>/landmarks``.

    :param forward_map: forward transformation map; ``shape[2:]`` gives the grid size
    :param inv_map: inverse transformation map on the same grid
    :param pair_name: case name passed to ``get_landmark`` and used in file names
    :param moving: unused (kept for interface compatibility)
    :param target: unused (kept for interface compatibility)
    :param record_path: folder under which ``landmarks`` is created
    :return: ``(target->source error, source->target error)``, each the mean L2 norm of
        the landmark differences scaled by ``physical_spacing``
    """
    transform_shape = np.array(forward_map.shape[2:])
    slandmark_index, tlandmark_index, physical_spacing = get_landmark(pair_name, transform_shape)
    spacing = 1. / (transform_shape - 1)
    slandmark_img_coord = spacing * slandmark_index
    tlandmark_img_coord = spacing * tlandmark_index

    def warp_landmarks(tf_map, landmark_img_coord):
        # Pack the N landmarks as a (1, 3, N, 1, 1) sampling grid in [-1, 1] and sample tf_map there.
        coord = torch.Tensor(landmark_img_coord.transpose(1, 0)).view([1, 3, -1, 1, 1])
        coord = coord.to(tf_map.device) * 2 - 1
        sampled = py_utils.compute_warped_image_multiNC(tf_map, coord, spacing, 1,
                                                        zero_boundary=False, use_01_input=False)
        # explicit indexing instead of squeeze(): a single landmark must stay (1, 3)
        return sampled[0, :, :, 0, 0].transpose(1, 0).detach().cpu().numpy()

    ts_landmark_img_coord = warp_landmarks(forward_map, tlandmark_img_coord)
    st_landmark_img_coord = warp_landmarks(inv_map, slandmark_img_coord)
    diff_ts = (slandmark_img_coord - ts_landmark_img_coord) / spacing * physical_spacing
    diff_st = (tlandmark_img_coord - st_landmark_img_coord) / spacing * physical_spacing

    landmark_saving_folder = os.path.join(record_path, "landmarks")
    os.makedirs(landmark_saving_folder, exist_ok=True)
    for suffix, points in (("source", slandmark_img_coord),
                           ("target", tlandmark_img_coord),
                           ("target_warp_to_source", ts_landmark_img_coord),
                           ("source_warp_to_target", st_landmark_img_coord)):
        save_vtk(os.path.join(landmark_saving_folder, "{}_{}.vtk".format(pair_name, suffix)),
                 {"points": points})

    return np.linalg.norm(diff_ts, ord=2, axis=1).mean(), np.linalg.norm(diff_st, ord=2, axis=1).mean()
Esempio n. 9
0
    def affine_forward(self, moving, target=None):
        """Run the affine pre-alignment (when enabled) and return the affinely warped image and map.

        :param moving: moving image batch
        :param target: target image batch; only used when ``self.using_affine_init`` is True
        :return: tuple ``(affine_img, affine_map)``. Without affine init, ``affine_map`` is a
            clone of the first ``moving.shape[0]`` entries of ``self.identityMap`` and
            ``affine_img`` is ``moving`` itself.

        Side effects: with affine init, sets ``self.affine_param`` and ``self.inverse_map``
        (None unless ``self.compute_inverse_map``). Without affine init, sets
        ``self.inverse_map`` to an identity clone only when ``self.compute_inverse_map``
        is True.
        """
        if self.using_affine_init:
            # the whole affine step runs without autograd
            with torch.no_grad():
                toaffine_moving, toaffine_target = moving, target
                # resize the net inputs only when every spatial entry of the affine
                # resolution is set (i.e. none of them is -1)
                resize_affine_input = all(
                    [sz != -1 for sz in self.affine_resoltuion[2:]])
                if resize_affine_input:
                    toaffine_moving = get_resampled_image(
                        toaffine_moving,
                        self.spacing,
                        self.affine_resoltuion,
                        identity_map=self.affineIdentityMap)
                    toaffine_target = get_resampled_image(
                        toaffine_target,
                        self.spacing,
                        self.affine_resoltuion,
                        identity_map=self.affineIdentityMap)
                affine_img, affine_map, affine_param = self.affine_net(
                    toaffine_moving, toaffine_target)
                self.affine_param = affine_param
                # rescale from [-1, 1] to [0, 1]; the warp below uses use_01_input=True
                affine_map = (affine_map + 1) / 2.
                inverse_map = None
                if self.compute_inverse_map:
                    inverse_map = self.affine_net.get_inverse_map(use_01=True)
                if resize_affine_input:
                    # The net's image output was computed on the resized inputs, so re-warp
                    # the original moving image. NOTE(review): affine_map presumably has
                    # the affine resolution here, so the output may too; confirm.
                    affine_img = py_utils.compute_warped_image_multiNC(
                        moving,
                        affine_map,
                        self.spacing,
                        1,
                        zero_boundary=True,
                        use_01_input=True)
                if self.using_physical_coord:
                    # scale each coordinate channel by spacing / standard_spacing
                    for i in range(self.dim):
                        affine_map[:, i] = affine_map[:, i] * self.spacing[
                            i] / self.standard_spacing[i]
                    if self.compute_inverse_map:
                        for i in range(self.dim):
                            inverse_map[:,
                                        i] = inverse_map[:, i] * self.spacing[
                                            i] / self.standard_spacing[i]
                self.inverse_map = inverse_map
        else:
            # no affine pre-alignment: start from (a clone of) the identity map
            num_b = moving.shape[0]
            affine_map = self.identityMap[:num_b].clone()
            if self.compute_inverse_map:
                self.inverse_map = self.identityMap[:num_b].clone()

            affine_img = moving
        return affine_img, affine_map
Esempio n. 10
0
def resample_warped_phi_and_image(source_path_list, l_source_path_list, phi,
                                  spacing):
    """Warp a batch of source images (and optionally their labels) with ``phi``.

    ``phi`` is first resampled to the size of the source images. Images are warped with
    spline order 1 and labels with spline order 0, both with zero boundary.

    :param source_path_list: non-empty list of image paths (all of the same size)
    :param l_source_path_list: list of label paths, or None
    :param phi: transformation map to resample
    :param spacing: spacing of ``phi``
    :return: ``(new_phi, warped, l_warped, new_spacing)``; ``l_warped`` is None without labels
    :raises ValueError: if ``source_path_list`` is empty
    """
    if not source_path_list:
        raise ValueError("source_path_list must contain at least one image path")

    def load_batch(path_list):
        # stack the images and add a channel axis -> (N, 1, *spatial)
        arrays = [sitk.GetArrayFromImage(sitk.ReadImage(f)) for f in path_list]
        return np.stack(arrays, axis=0)[:, None]

    source_np = load_batch(source_path_list)
    # keep the image size separate: it used to be overwritten by the label size,
    # which made phi get resampled to the label shape instead of the image shape
    img_sz = list(source_np.shape)
    source = MyTensor(source_np)

    l_source = None
    if l_source_path_list is not None:
        l_source = MyTensor(load_batch(l_source_path_list))

    new_phi, new_spacing = resample_image(phi,
                                          spacing,
                                          img_sz,
                                          1,
                                          zero_boundary=True)
    warped = py_utils.compute_warped_image_multiNC(source,
                                                   new_phi,
                                                   new_spacing,
                                                   1,
                                                   zero_boundary=True)
    l_warped = None
    if l_source is not None:
        l_warped = py_utils.compute_warped_image_multiNC(l_source,
                                                         new_phi,
                                                         new_spacing,
                                                         0,
                                                         zero_boundary=True)
    return new_phi, warped, l_warped, new_spacing
Esempio n. 11
0
def compute_reg_baseline(l_path_list, atlas_label_path, l_phi_path_list):
    """Dice between the atlas label warped to each subject and that subject's label.

    The atlas label is warped with each subject's map from ``l_phi_path_list`` using
    nearest-neighbour interpolation (spline order 0), then compared with the subject
    labels via ``get_multi_metric``.

    :param l_path_list: subject label paths
    :param atlas_label_path: atlas label path
    :param l_phi_path_list: per-subject maps (same order as ``l_path_list``)
    :return: ``(average_dice, detailed_dice)``. ``average_dice`` is the mean over column
        ``1:`` of row 0 of ``res['batch_avg_res']['dice']``, and ``detailed_dice`` is that
        full array. The average is still printed; the function previously returned None.
    """
    def read_array(pth):
        return sitk.GetArrayFromImage(sitk.ReadImage(pth))

    labels = torch.Tensor(np.stack([read_array(pth)[None] for pth in l_path_list]))
    # the same atlas label is warped once per subject
    atlas_np = read_array(atlas_label_path)[None]
    atlas_label = torch.Tensor(np.stack([atlas_np] * len(l_path_list)))
    to_atlas = torch.Tensor(
        np.stack([np.transpose(read_array(pth)) for pth in l_phi_path_list]))

    spacing = 1. / (np.array(labels.shape[2:]) - 1)
    warped_label = compute_warped_image_multiNC(atlas_label,
                                                to_atlas,
                                                spacing=spacing,
                                                spline_order=0,
                                                zero_boundary=True)
    res = get_multi_metric(warped_label, labels)
    detailed_dice = res['batch_avg_res']['dice']
    # column 0 is excluded from the average (presumably the background label)
    average_dice = np.mean(detailed_dice[0, 1:])
    print(average_dice)
    return average_dice, detailed_dice
Esempio n. 12
0
def save_jacobi_map(map, img_sz, fname, output_path, save_neg_jacobi=True):
    """Compute a per-voxel Jacobian approximation of ``map``, print fold statistics and save it.

    :param map: transformation map tensor; channels 0, 1, 2 are differentiated along x, y, z
    :param img_sz: spatial size at which the Jacobian is evaluated; ``map`` is resampled
        to it when its spatial size differs
    :param fname: output base name (``.nii.gz`` is appended)
    :param output_path: output folder
    :param save_neg_jacobi: if True, save the negative values only (zero elsewhere);
        otherwise save the absolute values
    """
    img_sz = np.array(img_sz)
    map_sz = np.array(map.shape[2:])
    spacing = 1. / (img_sz - 1)  # grid spacing of a unit-length grid with img_sz points

    if not np.array_equal(img_sz, map_sz):
        # identity_map_multiN expects the full (B, C, *spatial) size and returns a numpy
        # array; convert it to a tensor with map's dtype and device before warping
        id_sz = [map.shape[0], map.shape[1]] + list(img_sz)
        id_map = map.new_tensor(py_utils.identity_map_multiN(id_sz, spacing))
        map = py_utils.compute_warped_image_multiNC(map,
                                                    id_map,
                                                    spacing,
                                                    1,
                                                    zero_boundary=False)
    map = map.detach().cpu().numpy()

    fd = fdt.FD_np(spacing)
    dfx = fd.dXc(map[:, 0, ...])
    dfy = fd.dYc(map[:, 1, ...])
    dfz = fd.dZc(map[:, 2, ...])
    # NOTE: product of the diagonal partial derivatives only; the off-diagonal terms
    # of the full Jacobian determinant are not included
    jacobi_det = dfx * dfy * dfz
    jacobi_neg_bool = jacobi_det < 0.
    jacobi_abs_scalar = -np.sum(jacobi_det[jacobi_neg_bool])
    jacobi_num_scalar = np.sum(jacobi_neg_bool)
    print("fname:{}  folds for each channel {},{},{}".format(
        fname, np.sum(dfx < 0.), np.sum(dfy < 0.), np.sum(dfz < 0.)))
    print("fname:{} the jacobi_value of fold points  is {}".format(
        fname, jacobi_abs_scalar))
    print("fname:{} the number of fold points is {}".format(
        fname, jacobi_num_scalar))

    if save_neg_jacobi:
        # keep the image shape: negative values at fold points, zero elsewhere
        # (indexing the flattened list of negative values produced no image)
        jacobi_to_save = np.where(jacobi_neg_bool, jacobi_det, 0.)
    else:
        jacobi_to_save = np.abs(jacobi_det)

    num_b = jacobi_to_save.shape[0]
    for i in range(num_b):
        # one file per batch element; a single element keeps the plain file name
        suffix = '' if num_b == 1 else '_{}'.format(i)
        pth = os.path.join(output_path, fname + suffix) + '.nii.gz'
        sitk.WriteImage(sitk.GetImageFromArray(jacobi_to_save[i]), pth)
    def warp_mesh(self, source_patient):
        """Push the source patient's mesh points through the inverse transform map.

        Points are normalized with ``self.normalize_mesh`` and used as a (1, 3, N, 1, 1)
        sampling grid in [-1, 1] on the inverse map from ``self.get_transform_map``. The
        result is mapped back with ``self.get_mesh_in_original_space``.

        :param source_patient: object providing ``get_mesh_list()``
        :return: list with one warped entry per element of the mesh list
        """
        meshes = source_patient.get_mesh_list()
        inv_map, img_sz = self.get_transform_map()
        spacing = 1. / (img_sz - 1)

        normalized = self.normalize_mesh(meshes, spacing)
        points_3n = np.transpose(np.array(normalized))  # N*3 -> 3*N
        grid = MyTensor(points_3n).view([1, 3, -1, 1, 1]) * 2 - 1
        sampled = compute_warped_image_multiNC(inv_map,
                                               grid,
                                               spacing,
                                               spline_order=1,
                                               zero_boundary=False,
                                               use_01_input=False)

        points_n3 = np.transpose(sampled.cpu().numpy()[0, :, :, 0, 0])  # back to N*3
        in_original = self.get_mesh_in_original_space(points_n3, spacing)
        return [in_original[idx] for idx in range(len(meshes))]
Esempio n. 14
0
def resample_image(I,
                   spacing,
                   desiredSize,
                   spline_order=1,
                   zero_boundary=False,
                   identity_map=None):
    """
    Resample an image to a given desired size

    :param I: Input image (expected to be of BxCxXxYxZ format)
    :param spacing: array describing the spatial spacing of ``I``
    :param desiredSize: desired spatial size (1 entry for 1D, 2 for 2D, 3 for 3D). If its
        length does not equal ``I.dim() - 2``, its first two entries (B and C) are dropped.
    :param spline_order: interpolation order passed to compute_warped_image_multiNC
    :param zero_boundary: passed to compute_warped_image_multiNC
    :param identity_map: optional precomputed sampling map for the output grid; when None
        it is built from the new spacing on ``I``'s device
    :return: returns a tuple: the resampled image, the new spacing after resampling
    """
    if len(I.shape) != len(desiredSize) + 2:
        desiredSize = desiredSize[2:]

    in_sz = np.array(list(I.size()))
    out_sz = np.array([in_sz[0], in_sz[1]] + list(desiredSize))

    # keep the physical extent fixed: (n_in - 1) * spacing == (n_out - 1) * new_spacing
    in_extent = in_sz[2:].astype('float') - 1.
    out_extent = out_sz[2:].astype('float') - 1.
    new_spacing = spacing * (in_extent / out_extent)

    sampling_map = identity_map
    if sampling_map is None:
        sampling_map = torch.from_numpy(
            py_utils.identity_map_multiN(out_sz, new_spacing)).to(I.device)

    resampled = py_utils.compute_warped_image_multiNC(I, sampling_map, new_spacing,
                                                      spline_order, zero_boundary)
    return resampled, new_spacing
Esempio n. 15
0
    def generate_single_res(self, moving, l_moving, momentum, init_weight,
                            initial_map, initial_inverse_map, fname, t_aug,
                            output_path, moving_path):
        """Generate one augmented sample by shooting ``moving`` with a momentum and save it.

        Files written to ``output_path`` (``moving_path`` is the reference image):
        ``<fname>_image``; ``<fname>_label`` when ``l_moving`` is given; with
        ``self.save_tf_map`` also ``<fname>_phi_map``, plus ``<fname>_inv_map`` when
        ``self.compute_inverse`` is set.

        :param moving: moving image tensor; its shape is compared with ``[1, 1] + spatial``
            sizes, so presumably (1, 1, X, Y, Z)
        :param l_moving: label map of ``moving``, or None
        :param momentum: initial momentum, or None to draw a random one (created via ``.cuda()``)
        :param init_weight: passed to the model as ``local_weights``
        :param initial_map: optional initial map passed to the model as ``init_map``
        :param initial_inverse_map: optional inverse of the initial map
        :param fname: output file name prefix
        :param t_aug: written to ``['model']['registration_model']['forward_model']['tTo']``
        :param output_path: output folder
        :param moving_path: reference image path used when saving
        """
        # NOTE: params is the same object as self.mermaid_setting, so tTo is set in place
        params = self.mermaid_setting
        params['model']['registration_model']['forward_model']['tTo'] = t_aug

        # here we assume the momentum is computed at low_resol_factor=0.5
        if momentum is not None:
            input_img_sz = [1, 1] + [int(sz * 2) for sz in momentum.shape[2:]]
        else:
            # random momentum in [-magnitude, magnitude) drawn on a coarse grid and
            # upsampled to half of the image resolution
            input_img_sz = list(moving.shape)
            momentum_sz_low = [1, 3] + [
                int(dim / self.rand_momentum_shrink_factor)
                for dim in input_img_sz[2:]
            ]
            momentum_sz = [1, 3] + [int(dim / 2) for dim in input_img_sz[2:]]
            momentum = (np.random.rand(*momentum_sz_low) * 2 -
                        1) * self.magnitude
            mom_spacing = 1. / (np.array(momentum_sz_low[2:]) - 1)
            momentum = torch.Tensor(momentum).cuda()
            momentum, _ = resample_image(momentum,
                                         mom_spacing,
                                         momentum_sz,
                                         spline_order=1,
                                         zero_boundary=True)
        if self.resize_output != [-1, -1, -1]:
            # resample the momentum so the generated output has size self.resize_output
            momentum_sz = [1, 3] + [int(dim / 2) for dim in self.resize_output]
            mom_spacing = 1. / (np.array(momentum_sz[2:]) - 1)
            momentum, _ = resample_image(momentum,
                                         mom_spacing,
                                         momentum_sz,
                                         spline_order=1,
                                         zero_boundary=True)
            input_img_sz = [1, 1] + [int(sz * 2) for sz in momentum.shape[2:]]

        org_spacing = 1.0 / (np.array(moving.shape[2:]) - 1)
        input_spacing = 1.0 / (np.array(input_img_sz[2:]) - 1)
        size_diff = not input_img_sz == list(moving.shape)
        if size_diff:
            input_img, _ = resample_image(moving, org_spacing, input_img_sz)
        else:
            input_img = moving
        # low-resolution versions of the initial maps at the momentum resolution
        low_initial_map = None
        low_init_inverse_map = None
        if initial_map is not None:
            low_initial_map, _ = resample_image(
                initial_map, input_spacing, [1, 3] + list(momentum.shape[2:]))
        if initial_inverse_map is not None:
            low_init_inverse_map, _ = resample_image(initial_inverse_map,
                                                     input_spacing, [1, 3] +
                                                     list(momentum.shape[2:]))
        individual_parameters = dict(m=momentum, local_weights=init_weight)
        sz = np.array(input_img.shape)
        extra_info = None
        visual_param = None
        # shoot the image onto itself: source and target are both input_img
        res = evaluate_model(input_img,
                             input_img,
                             sz,
                             input_spacing,
                             use_map=True,
                             compute_inverse_map=self.compute_inverse,
                             map_low_res_factor=0.5,
                             compute_similarity_measure_at_low_res=False,
                             spline_order=1,
                             individual_parameters=individual_parameters,
                             shared_parameters=None,
                             params=params,
                             extra_info=extra_info,
                             visualize=False,
                             visual_param=visual_param,
                             given_weight=False,
                             init_map=initial_map,
                             lowres_init_map=low_initial_map,
                             init_inverse_map=initial_inverse_map,
                             lowres_init_inverse_map=low_init_inverse_map)
        phi = res[1]
        phi_new = phi
        if size_diff:
            phi_new, _ = resample_image(phi, input_spacing,
                                        [1, 3] + list(moving.shape[2:]))
        warped = compute_warped_image_multiNC(moving,
                                              phi_new,
                                              org_spacing,
                                              spline_order=1,
                                              zero_boundary=True)
        # whether phi_new gets composed with the initial inverse map below
        composed_back = initial_inverse_map is not None and self.affine_back_to_original_postion
        if composed_back:
            # here we take zero boundary boundary which need two step image interpolation
            warped = compute_warped_image_multiNC(warped,
                                                  initial_inverse_map,
                                                  org_spacing,
                                                  spline_order=1,
                                                  zero_boundary=True)
            phi_new = compute_warped_image_multiNC(phi_new,
                                                   initial_inverse_map,
                                                   org_spacing,
                                                   spline_order=1)
        save_image_with_given_reference(warped, [moving_path], output_path,
                                        [fname + '_image'])
        if l_moving is not None:
            # we assume the label doesnt lie at the boundary
            l_warped = compute_warped_image_multiNC(l_moving,
                                                    phi_new,
                                                    org_spacing,
                                                    spline_order=0,
                                                    zero_boundary=True)
            save_image_with_given_reference(l_warped, [moving_path],
                                            output_path, [fname + '_label'])

        if self.save_tf_map:
            save_deformation(phi_new, output_path, [fname + '_phi_map'])
            if self.compute_inverse:
                phi_inv = res[2]
                inv_phi_new = phi_inv
                # The model's inverse no longer matches phi_new only when phi_new was
                # composed with the initial inverse map above. Previously this check
                # ignored whether initial_inverse_map was given.
                if composed_back:
                    print(
                        "Cannot compute the inverse map when affine back to the source image position"
                    )
                    return
                if size_diff:
                    inv_phi_new, _ = resample_image(phi_inv, input_spacing,
                                                    [1, 3] +
                                                    list(moving.shape[2:]))
                save_deformation(inv_phi_new, output_path,
                                 [fname + '_inv_map'])
Esempio n. 16
0
# Script fragment (it appears to compare ITK and mermaid warps). moving_itk, target_itk,
# inv_trans_itk, warped_itk, mermaid_transform_path and mermaid_inv_transform_path are
# defined earlier in the script, outside this view.
inv_warped_itk = sitk_grid_sampling(moving_itk, target_itk, inv_trans_itk)

moving_np = sitk.GetArrayFromImage(moving_itk).astype(np.float32)
target_np = sitk.GetArrayFromImage(target_itk).astype(np.float32)
# Read the mermaid forward/inverse maps and reverse their axis order with
# transpose(3, 2, 1, 0); presumably sitk (Z, Y, X, C) -> (C, X, Y, Z), confirm.
mermaid_phi = sitk.GetArrayFromImage(
    sitk.ReadImage(mermaid_transform_path)).transpose(3, 2, 1, 0)
mermaid_inv_phi = sitk.GetArrayFromImage(
    sitk.ReadImage(mermaid_inv_transform_path)).transpose(3, 2, 1, 0)

# forward warp: normalized spacing 1/(n-1) per spatial axis of the map
phi_sz = np.array(mermaid_phi.shape)
spacing = 1. / (np.array(phi_sz[1:]) - 1)
moving = torch.from_numpy(moving_np[None][None])
mermaid_phi = torch.from_numpy(mermaid_phi[None])
warped_mermaid = compute_warped_image_multiNC(moving,
                                              mermaid_phi,
                                              spacing,
                                              1,
                                              zero_boundary=True)

# inverse warp of the target, same procedure with the inverse map
inv_phi_sz = np.array(mermaid_inv_phi.shape)
spacing = 1. / (np.array(inv_phi_sz[1:]) - 1)
target = torch.from_numpy(target_np[None][None])
mermaid_inv_phi = torch.from_numpy(mermaid_inv_phi[None])
inv_warped_mermaid = compute_warped_image_multiNC(target,
                                                  mermaid_inv_phi,
                                                  spacing,
                                                  1,
                                                  zero_boundary=True)

# NOTE(review): only warped_itk is written within this view; inv_warped_itk,
# warped_mermaid and inv_warped_mermaid are presumably saved further down (confirm).
output_path = "/playpen-raid1/zyshen/data/demo_for_lung_reg"
sitk.WriteImage(warped_itk, os.path.join(output_path, "warped_itk.nii.gz"))
Esempio n. 17
0
    def do_mermaid_reg(self,
                       mermaid_unit,
                       criterion,
                       s,
                       t,
                       m,
                       phi,
                       low_s=None,
                       low_t=None,
                       inv_map=None):
        """
        Run one mermaid registration unit and warp the source with the resulting map.

        With ``self.mermaid_low_res_factor`` set, the unit runs on the low-resolution
        inputs (``self.lowRes_fn`` is applied to the initial map(s)). The resulting map(s)
        are then resampled back to ``self.img_sz``. The final map is stored in
        ``self.rec_phiWarped``; the inverse map, when requested, in ``self.inverse_map``.

        :param mermaid_unit: callable registration unit
        :param criterion: passed to ``set_mermaid_param``
        :param s: source image
        :param t: target image
        :param m: initial momentum
        :param phi: initial deformation field
        :param low_s: downsampled source
        :param low_t: downsampled target
        :param inv_map: initial inverse map (used only when ``self.compute_inverse_map``)
        :return: warped image, transformation map
        """
        low_res = self.mermaid_low_res_factor is not None
        if low_res:
            self.set_mermaid_param(mermaid_unit, criterion, low_s, low_t, m, s)
            unit_phi, unit_source = self.lowRes_fn(phi), low_s
        else:
            self.set_mermaid_param(mermaid_unit, criterion, s, t, m, s)
            unit_phi, unit_source = phi, s

        optimizer_vars = {'epoch': self.epoch}
        if self.compute_inverse_map:
            unit_inv = self.lowRes_fn(inv_map) if low_res else inv_map
            maps, inverse_maps = mermaid_unit(unit_phi,
                                              unit_source,
                                              phi_inv=unit_inv,
                                              variables_from_optimizer=optimizer_vars)
        else:
            maps = mermaid_unit(unit_phi,
                                unit_source,
                                variables_from_optimizer=optimizer_vars)

        if low_res:
            # bring the low-resolution map(s) back to the full image size
            rec_phiWarped = get_resampled_image(maps,
                                                self.lowResSpacing,
                                                self.img_sz,
                                                1,
                                                zero_boundary=False,
                                                identity_map=self.identityMap)
            if self.compute_inverse_map:
                self.inverse_map = get_resampled_image(inverse_maps,
                                                       self.lowResSpacing,
                                                       self.img_sz,
                                                       1,
                                                       zero_boundary=False,
                                                       identity_map=self.identityMap)
        else:
            rec_phiWarped = maps
            if self.compute_inverse_map:
                self.inverse_map = inverse_maps

        rec_IWarped = py_utils.compute_warped_image_multiNC(s,
                                                            rec_phiWarped,
                                                            self.spacing,
                                                            1,
                                                            zero_boundary=True)
        self.rec_phiWarped = rec_phiWarped
        return rec_IWarped, rec_phiWarped
Esempio n. 18
0
        def compute_warped_image_label(input,
                                       warped_pth,
                                       warped_type,
                                       inv_phi_pth,
                                       inv_switcher,
                                       num_max=50,
                                       weight_for_orig_img=0):
            """
            Test-time-augmentation ensemble for the current image.

            Each augmented (warped) copy of the current image is segmented. Its
            softmax class probabilities are warped back with the matching inverse
            map, and the probabilities are summed. The label is the argmax of the sum.

            :param input: network input for the original image; only used when
                ``weight_for_orig_img`` is non-zero
            :param warped_pth: folder containing the augmented images
            :param warped_type: glob pattern, relative to ``warped_pth``, selecting them
            :param inv_phi_pth: folder containing the inverse maps
            :param inv_switcher: ``(old, new)`` pair passed to ``str.replace`` to turn
                an augmented-image path into its inverse-map path
            :param num_max: maximum number of augmented images to use
            :param weight_for_orig_img: the original-image prediction is added with
                weight ``round(weight_for_orig_img * num_aug)``
            :return: argmax over the class axis of the accumulated scores, with an
                extra leading dimension
            :raises ValueError: if no augmented image of the current image is found
            """
            read_image = lambda pth: sitk.GetArrayFromImage(sitk.ReadImage(pth))
            fname = get_file_name(self.fname[0])
            # Augmented files are named after their source image: "<fname>_...".
            aug_prefix = fname + '_'

            # sorted() makes the subset kept by ``num_max`` reproducible; raw glob
            # order depends on the file system.
            warped_sub_list = [
                pth for pth in sorted(glob(os.path.join(warped_pth, warped_type)))
                if get_file_name(pth).startswith(aug_prefix)
            ][:num_max]
            # Derive each inverse map from its own augmented image so the two lists
            # are always index-aligned (filtering them separately could misalign).
            inv_phi_sub_list = [
                pth.replace(warped_pth, inv_phi_pth).replace(*inv_switcher)
                for pth in warped_sub_list
            ]
            num_aug = len(warped_sub_list)
            if num_aug == 0:
                raise ValueError(
                    "no augmented image matching '{}' found for '{}'".format(
                        os.path.join(warped_pth, warped_type), fname))

            warped_img = np.stack([read_image(pth) for pth in warped_sub_list],
                                  0)[:, None]
            warped_img = torch.Tensor(self.normalize_input(warped_img, None))
            # GetArrayFromImage returns (z, y, x, c); reorder to (N, c, x, y, z).
            inv_phi = np.stack([read_image(pth) for pth in inv_phi_sub_list], 0)
            inv_phi = torch.Tensor(np.transpose(inv_phi, (0, 4, 3, 2, 1)))

            img_input_sz = self.opt["dataset"]["img_after_resize"]
            differ_sz = any(
                np.array(warped_img.shape[2:]) != np.array(img_input_sz))

            sz = np.array(self.img_sz)
            spacing = 1. / (sz - 1)
            output_np = np.zeros([1, self.num_class] + self.img_sz)
            if weight_for_orig_img != 0:
                tzero_img = self.get_assemble_pred_for_ensemble(input)
                tzero_pred = self.partition.assemble_multi_torch(
                    tzero_img, image_size=self.img_sz)
                output_np = tzero_pred.cpu().numpy() * float(
                    round(weight_for_orig_img * num_aug))

            for i in range(num_aug):
                warped_img_cur = warped_img[i:i + 1]
                inv_phi_cur = inv_phi[i:i + 1]
                if differ_sz:
                    # Desired sizes as [B, C] + spatial: 1 channel for the image,
                    # 3 vector components for the map.
                    warped_img_cur, _ = resample_image(warped_img_cur.cuda(),
                                                       [1, 1, 1],
                                                       [1, 1] + self.img_sz)
                    inv_phi_cur, _ = resample_image(inv_phi_cur.cuda(),
                                                    [1, 1, 1],
                                                    [1, 3] + self.img_sz)
                    warped_img_cur = warped_img_cur.detach().cpu()
                    inv_phi_cur = inv_phi_cur.detach().cpu()
                sample = {"image": [warped_img_cur[0, 0].numpy()]}
                sample_p = corr_partition_pool(sample)
                pred_patched = self.get_assemble_pred_for_ensemble(
                    torch.Tensor(sample_p["image"]).cuda())
                pred_patched = self.partition.assemble_multi_torch(
                    pred_patched, image_size=self.img_sz)
                pred_patched = torch.nn.functional.softmax(pred_patched, 1)
                # Pull the class probabilities back into the original image space.
                pred_patched = compute_warped_image_multiNC(
                    pred_patched.cuda(),
                    inv_phi_cur.cuda(),
                    spacing,
                    spline_order=1,
                    zero_boundary=True)
                output_np += pred_patched.cpu().numpy()
            res = torch.max(torch.Tensor(output_np), 1)[1]
            return res[None]
Esempio n. 19
0
    def nonp_optimization(self):
        """
        Run mermaid's non-parametric registration of ``self.moving`` to ``self.target``.

        If an affine registration was performed first (``self.affine_on``), its
        transformation map is used as the initial map. If ``self.use_init_weight``
        is set, an initial weight map is computed from the moving label map. When an
        affine map exists, the weights are warped by it. The weights are then set as a
        frozen weight map, so the registration model must support a spatially
        varying regularizer.

        :return: warped image, transformation map, affined image (or None), loss (None)
        """
        affine_map = None
        if self.affine_on:
            # Must be read before self.si is replaced by a fresh object below.
            affine_map = self.si.opt.optimizer.ssOpt.get_map()

        self.si = SI.RegisterImagePair()
        extra_info = pars.ParameterDict()
        extra_info['pair_name'] = self.fname_list
        self.si.opt = None
        if affine_map is not None:
            self.si.set_initial_map(affine_map.detach(), self.inversed_map)

        if self.use_init_weight:
            init_weight = get_init_weight_from_label_map(
                self.l_moving, self.spacing, self.weights_for_bg,
                self.weights_for_fg)
            if affine_map is not None:
                # Apply the same affine map that initializes the registration.
                # Without an affine step the initial map is the identity, so no
                # warp is needed (warping by None would crash).
                init_weight = py_utils.compute_warped_image_multiNC(
                    init_weight,
                    affine_map,
                    self.spacing,
                    spline_order=1,
                    zero_boundary=False)
            self.si.set_weight_map(init_weight.detach(), freeze_weight=True)

        if self.saved_mermaid_setting_path is None:
            self.saved_mermaid_setting_path = self.save_setting(
                self.setting_for_mermaid_nonp, self.record_path,
                "nonp_setting.json")
        cur_mermaid_json_saving_path = (os.path.join(self.record_path,
                                                     'cur_settings_nonp.json'),
                                        os.path.join(
                                            self.record_path,
                                            'cur_settings_nonp_comment.json'))
        self.si.register_images(
            self.moving,
            self.target,
            self.spacing,
            extra_info=extra_info,
            LSource=self.l_moving,
            LTarget=self.l_target,
            visualize_step=None,
            use_multi_scale=True,
            rel_ftol=0,
            compute_inverse_map=self.compute_inverse_map,
            json_config_out_filename=cur_mermaid_json_saving_path,
            params=self.saved_mermaid_setting_path)
        # The previous output (the affine image, if any) is kept as the third return value.
        self.afimg_or_afparam = self.output
        self.output = self.si.get_warped_image()
        self.phi = self.si.opt.optimizer.ssOpt.get_map()

        if self.compute_inverse_map:
            self.inversed_map = self.si.get_inverse_map().detach()

        warped = self.output.detach_()
        phi = self.phi.detach_()
        affine_out = (self.afimg_or_afparam.detach_()
                      if self.afimg_or_afparam is not None else None)
        return warped, phi, affine_out, None
Esempio n. 20
0
    def generate_aug_data(self, path_list, fname_list, init_weight_path_list,
                          output_path):
        """
        Generate atlas-based augmentations for every image in ``path_list``.

        Each image is first mapped towards the atlas by integrating its
        image-to-atlas momentum (to t=1). Starting from that map, a random convex
        combination of ``K`` atlas-to-image momenta is then integrated to a random
        time in ``self.t_range``. The warped images (and label maps, when present)
        are saved to ``output_path``. Inverse maps are saved when both
        ``self.save_tf_map`` and ``self.compute_inverse`` are set.

        The low-level mermaid interface is used so the composition is integrated at
        low resolution. This avoids saving phi / inverse phi and the precision loss
        of unnecessary upsampling and downsampling, which keeps the maps accurate.

        :param path_list: per-image path lists, each passed to ``self.get_input``;
            element 0 is the image path, also used as the reference when saving
        :param fname_list: optional per-image names passed to ``self.get_input``
        :param init_weight_path_list: must be None (initial weights not supported)
        :param output_path: folder where augmented images / labels / maps are written
        :raises ValueError: if ``self.rand_w_t`` is False (data interpolation is
            disabled in atlas augmentation mode)
        """
        def create_mermaid_model(mermaid_json_pth,
                                 img_sz,
                                 compute_inverse=True):
            """
            Build a mermaid registration model on the GPU with a 0.5 low-res factor.

            :param mermaid_json_pth: mermaid settings json
            :param img_sz: full image size as [B, C, X, Y, Z]
            :param compute_inverse: currently ignored. The model is always built
                with ``compute_inverse_map=True`` because ``_do_mermaid_reg``
                results are unpacked as (map, inverse map).
            :return: model, criterion, low-res identity map, low-res size,
                low-res spacing, full-res identity map, full-res spacing
            """
            import mermaid.model_factory as py_mf
            spacing = 1. / (np.array(img_sz[2:]) - 1)
            params = pars.ParameterDict()
            params.load_JSON(mermaid_json_pth)
            model_name = params['model']['registration_model']['type']
            params.print_settings_off()
            mermaid_low_res_factor = 0.5
            lowResSize = get_res_size_from_size(img_sz, mermaid_low_res_factor)
            lowResSpacing = get_res_spacing_from_spacing(
                spacing, img_sz, lowResSize)
            mf = py_mf.ModelFactory(img_sz, spacing, lowResSize, lowResSpacing)
            model, criterion = mf.create_registration_model(
                model_name, params['model'], compute_inverse_map=True)
            lowres_id = identity_map_multiN(lowResSize, lowResSpacing)
            lowResIdentityMap = torch.from_numpy(lowres_id).cuda()

            _id = identity_map_multiN(img_sz, spacing)
            identityMap = torch.from_numpy(_id).cuda()
            mermaid_unit_st = model.cuda()
            mermaid_unit_st.associate_parameters_with_module()
            return mermaid_unit_st, criterion, lowResIdentityMap, lowResSize, lowResSpacing, identityMap, spacing

        def _set_mermaid_param(mermaid_unit, m):
            """Load momentum ``m`` into the registration unit."""
            mermaid_unit.m.data = m

        def _do_mermaid_reg(mermaid_unit,
                            low_phi,
                            m,
                            low_s=None,
                            low_inv_phi=None):
            """Run the unit from initial map ``low_phi`` with momentum ``m`` (no grad).

            Returns whatever the unit returns. With an inverse map requested this
            is unpacked by the callers as (map, inverse map).
            """
            with torch.no_grad():
                _set_mermaid_param(mermaid_unit, m)
                low_phi = mermaid_unit(low_phi, low_s, phi_inv=low_inv_phi)
            return low_phi

        def get_momentum_name(momentum_path):
            """Strip the "_0000_Momentum" suffix from a momentum file name."""
            fname = get_file_name(momentum_path)
            fname = fname.replace("_0000_Momentum", '')
            return fname

        max_aug_num = self.max_aug_num
        rand_w_t = self.rand_w_t
        t_range = self.t_range
        t_span = t_range[1] - t_range[0]
        K = self.K

        num_pair = len(path_list)
        assert init_weight_path_list is None, "init weight has not supported yet"
        # The full-resolution inverse map is only built when it will be written.
        save_inverse = self.save_tf_map and self.compute_inverse

        # Load all momentums: atlas->image files (name starts with "atlas") from
        # atlas_to_folder, image->atlas files from to_atlas_folder.
        read_image = lambda x: sitk.GetArrayFromImage(sitk.ReadImage(x))
        atlas_to_momentum_path_list = list(
            filter(
                lambda x: "Momentum" in x and get_file_name(x).find("atlas") ==
                0, glob(os.path.join(self.atlas_to_folder, "*nii.gz"))))
        to_atlas_momentum_path_list = list(
            filter(
                lambda x: "Momentum" in x and get_file_name(x).find("atlas") !=
                0, glob(os.path.join(self.to_atlas_folder, "*nii.gz"))))
        # transpose() reverses all axes; [None] adds a batch dimension.
        atlas_to_momentum_list = [
            torch.Tensor(
                read_image(atlas_momentum_pth).transpose()[None]).cuda()
            for atlas_momentum_pth in atlas_to_momentum_path_list
        ]
        to_atlas_momentum_list = [
            torch.Tensor(
                read_image(atlas_momentum_pth).transpose()[None]).cuda()
            for atlas_momentum_pth in to_atlas_momentum_path_list
        ]
        num_momentum = len(atlas_to_momentum_list)
        moving_example = read_image(path_list[0][0])
        img_sz = list(moving_example.shape)
        # NOTE(review): ``mermaid_setting_path`` is not defined in this method; it
        # must come from module scope (or was ``self.mermaid_setting_path``
        # intended?) — confirm.
        mermaid_unit_st, criterion, lowResIdentityMap, lowResSize, lowResSpacing, identityMap, spacing = create_mermaid_model(
            mermaid_setting_path, [1, 1] + img_sz, self.compute_inverse)

        for i in range(num_pair):
            fname = fname_list[i] if fname_list is not None else None
            moving, l_moving, moving_name = self.get_input(
                path_list[i], fname, None)
            low_moving = get_resampled_image(moving,
                                             None,
                                             lowResSize,
                                             1,
                                             zero_boundary=True,
                                             identity_map=lowResIdentityMap)
            init_map = lowResIdentityMap.clone()
            init_inverse_map = lowResIdentityMap.clone()
            # NOTE(review): substring match; a name that is contained in another
            # subject's file name (e.g. "img1" vs "img10") may pick the wrong momentum.
            index = list(
                filter(lambda x: moving_name in x,
                       to_atlas_momentum_path_list))[0]
            index = to_atlas_momentum_path_list.index(index)
            # Map the image to the atlas (integration end time 'tTo' = 1); only
            # the maps are of interest here, so the moving image does not matter.
            mermaid_unit_st.integrator.cparams['tTo'] = 1.0
            low_phi_to_atlas, low_inverse_phi_to_atlas = _do_mermaid_reg(
                mermaid_unit_st,
                init_map,
                to_atlas_momentum_list[index],
                low_moving,
                low_inv_phi=init_inverse_map)
            num_aug = max(round(max_aug_num / num_pair), 1) if rand_w_t else 1

            for _ in range(num_aug):
                if rand_w_t:
                    t_aug_list = [random.random() * t_span + t_range[0]]
                    weight = np.array([random.random() for _ in range(K)])
                    weight_list = [weight / np.sum(weight)]
                    selected_index = random.sample(list(range(num_momentum)),
                                                   K)
                else:
                    raise ValueError(
                        "In atlas augmentation mode, the data interpolation is disabled"
                    )
                for t_aug in t_aug_list:
                    if t_aug == 0:
                        continue
                    for weight in weight_list:
                        # Convex combination of the selected atlas->image momenta;
                        # the output name records the sources, weights and time.
                        momentum = torch.zeros_like(atlas_to_momentum_list[0])
                        fname = moving_name + "_to_"
                        suffix = ""
                        for k in range(K):
                            momentum += weight[k] * atlas_to_momentum_list[
                                selected_index[k]]
                            fname += get_momentum_name(
                                atlas_to_momentum_path_list[
                                    selected_index[k]]) + '_'
                            suffix += '{:.4f}_'.format(weight[k])

                        fname = fname + suffix + 't_{:.2f}'.format(t_aug)
                        fname = fname.replace('.', 'd')

                        # Continue from the to-atlas maps out of the atlas.
                        mermaid_unit_st.integrator.cparams['tTo'] = t_aug
                        low_phi_atlas_to, low_inverse_phi_atlas_to = _do_mermaid_reg(
                            mermaid_unit_st,
                            low_phi_to_atlas.clone(),
                            momentum,
                            low_moving,
                            low_inv_phi=low_inverse_phi_to_atlas.clone())
                        forward_map = get_resampled_image(
                            low_phi_atlas_to,
                            lowResSpacing, [1, 3] + img_sz,
                            1,
                            zero_boundary=False,
                            identity_map=identityMap)
                        warped = compute_warped_image_multiNC(
                            moving,
                            forward_map,
                            spacing,
                            spline_order=1,
                            zero_boundary=True)
                        if l_moving is not None:
                            # Nearest-neighbour (spline order 0) for labels.
                            l_warped = compute_warped_image_multiNC(
                                l_moving,
                                forward_map,
                                spacing,
                                spline_order=0,
                                zero_boundary=True)
                            save_image_with_given_reference(
                                l_warped, [path_list[i][0]], output_path,
                                [fname + '_label'])
                        save_image_with_given_reference(
                            warped, [path_list[i][0]], output_path,
                            [fname + '_image'])
                        if save_inverse:
                            inverse_map = get_resampled_image(
                                low_inverse_phi_atlas_to,
                                lowResSpacing, [1, 3] + img_sz,
                                1,
                                zero_boundary=False,
                                identity_map=identityMap)
                            save_deformation(inverse_map, output_path,
                                             [fname + '_inv_phi'])
Esempio n. 21
0
 def compute_loss():
     """Sum of squared differences between map_t warped by invmap_t and id_t.

     If invmap_t is the inverse of map_t, warping map_t with it should yield the
     identity map id_t, so a perfect inverse gives a loss of zero.
     """
     residual = utils.compute_warped_image_multiNC(map_t, invmap_t, spacing) - id_t
     return (residual ** 2).sum()