Ejemplo n.º 1
0
def compute_atlas_label(lsource_folder_path, to_atlas_folder_pth, atlas_type,
                        atlas_to_l_switcher, output_folder, max_num_atlas=100):
    """
    Build an atlas label map by majority voting over labels warped to the atlas.

    Every transformation map in ``to_atlas_folder_pth`` whose file name ends with
    ``atlas_type`` is paired with a label file in ``lsource_folder_path``. The label
    name is obtained by applying ``str.replace(*atlas_to_l_switcher)`` to the map's
    file name and appending ``.nii.gz``. Each label is one-hot encoded and warped
    with its map (spline order 1, zero boundary). The warped one-hot volumes are
    summed over all pairs, and the per-voxel argmax is saved as ``atlas_label`` in
    ``output_folder``, using the first label file as the reference image.

    :param lsource_folder_path: folder containing the label images
    :param to_atlas_folder_pth: folder containing the to-atlas transformation maps
    :param atlas_type: file-name suffix selecting the maps (glob ``"*" + atlas_type``)
    :param atlas_to_l_switcher: (old, new) pair given to ``str.replace`` to turn a
        map file name into a label file name
    :param output_folder: folder where the voted label image is written
    :param max_num_atlas: use at most this many maps (the first ones in sorted
        order); None uses all matching maps
    :raises ValueError: if no transformation map matches
    """
    # sort so that the selected subset does not depend on the file-system listing order
    to_atlas_pth_list = sorted(
        glob(os.path.join(to_atlas_folder_pth, "*" + atlas_type)))
    if max_num_atlas is not None:
        to_atlas_pth_list = to_atlas_pth_list[:max_num_atlas]
    if not to_atlas_pth_list:
        raise ValueError("no transformation map matching '*{}' found in {}".format(
            atlas_type, to_atlas_folder_pth))

    read_array = lambda pth: sitk.GetArrayFromImage(sitk.ReadImage(pth))
    to_atlas_name_list = [get_file_name(pth) for pth in to_atlas_pth_list]
    l_pth_list = [
        os.path.join(lsource_folder_path,
                     name.replace(*atlas_to_l_switcher) + '.nii.gz')
        for name in to_atlas_name_list
    ]
    # labels get a channel axis -> stacked to N x 1 x spatial
    l = np.stack([read_array(pth)[None] for pth in l_pth_list]).astype(np.int64)
    # maps have all their axes reversed; presumably this converts the sitk array
    # layout into the channel-first layout compute_warped_image_multiNC expects -- confirm
    to_atlas = np.stack(
        [np.transpose(read_array(pth)) for pth in to_atlas_pth_list])
    # one channel per label value 0..max; len(np.unique(l)) is too small when the
    # label values are not contiguous
    num_c = int(l.max()) + 1
    l = torch.LongTensor(l)
    to_atlas = torch.Tensor(to_atlas)
    l_onehot = make_one_hot(l, C=num_c).to(torch.float32)
    spacing = 1. / (np.array(l.shape[2:]) - 1)
    warped_one_hot = compute_warped_image_multiNC(l_onehot,
                                                  to_atlas,
                                                  spacing=spacing,
                                                  spline_order=1,
                                                  zero_boundary=True)
    # accumulate votes over all warped labels, then take the winning class per voxel
    sum_one_hot = torch.sum(warped_one_hot, 0, keepdim=True)
    voting = torch.max(sum_one_hot, 1)[1][None].to(torch.float32)
    save_image_with_given_reference(voting, [l_pth_list[0]], output_folder,
                                    ["atlas_label"])
Ejemplo n.º 2
0
def compute_warped_image_label(img_label_txt_pth,phi_pth,phi_type, saving_pth):
    """
    Warp every (image, label) pair listed in a txt file with all of its
    transformation maps, and save the warped images and labels.

    A map belongs to an image when the map's file name starts with the image's
    file name. Images are warped with spline order 1 and labels with spline
    order 0, both with a zero boundary.

    NOTE(review): prefix matching means an image named e.g. ``img1`` also picks
    up maps of ``img10``; confirm the naming scheme avoids such collisions.

    :param img_label_txt_pth: txt file read with ``read_txt_into_list``; entry i is
        expected to hold [image_path, label_path, ...]
    :param phi_pth: folder containing the transformation maps
    :param phi_type: glob pattern (relative to ``phi_pth``) selecting the maps
    :param saving_pth: output folder
    """
    img_label_pth_list = read_txt_into_list(img_label_txt_pth)
    phi_pth_list = glob(os.path.join(phi_pth, phi_type))
    read_array = lambda pth: sitk.GetArrayFromImage(sitk.ReadImage(pth))
    for img_label_pth in img_label_pth_list:
        img_pth, label_pth = img_label_pth[0], img_label_pth[1]
        fname = get_file_name(img_pth)
        phi_sub_list = [
            pth for pth in phi_pth_list if get_file_name(pth).startswith(fname)
        ]
        num_aug = len(phi_sub_list)
        if num_aug == 0:
            # nothing to warp with; np.stack would fail on an empty list
            print("no transformation map found for {}, skipped".format(img_pth))
            continue
        # load lazily, one pair at a time, instead of keeping the whole dataset in memory
        img = torch.Tensor(read_array(img_pth)[None][None]).repeat(num_aug, 1, 1, 1, 1)
        label = torch.Tensor(read_array(label_pth)[None][None]).repeat(num_aug, 1, 1, 1, 1)
        phi = np.stack([read_array(pth) for pth in phi_sub_list], 0)
        # reverse the spatial/component axes behind the batch axis
        phi = torch.Tensor(np.transpose(phi, (0, 4, 3, 2, 1)))
        img_sz = list(img.shape[2:])
        spacing = 1. / (np.array(img_sz) - 1)
        # resample_image expects the spacing of its *input*, i.e. of phi's own grid
        phi_spacing = 1. / (np.array(phi.shape[2:]) - 1)
        phi, _ = resample_image(phi, phi_spacing, [1, 3] + img_sz)
        warped_img = compute_warped_image_multiNC(img, phi, spacing, spline_order=1, zero_boundary=True)
        warped_label = compute_warped_image_multiNC(label, phi, spacing, spline_order=0, zero_boundary=True)
        reference_list = [img_pth] * num_aug
        base_names = [get_file_name(pth).replace("_phi", "") for pth in phi_sub_list]
        save_image_with_given_reference(warped_img, reference_list, saving_pth,
                                        [name + '_warped' for name in base_names])
        save_image_with_given_reference(warped_label, reference_list, saving_pth,
                                        [name + '_label' for name in base_names])
Ejemplo n.º 3
0
# Warp the moving image with the mermaid forward map and the target image with
# the mermaid inverse map, then write both next to the ITK-warped results.
# NOTE(review): this fragment relies on names defined earlier in the script and
# not visible here: moving_np, target_np, mermaid_phi, mermaid_inv_phi, spacing,
# warped_itk, inv_warped_itk, moving_path, target_path.
moving = torch.from_numpy(moving_np[None][None])
mermaid_phi = torch.from_numpy(mermaid_phi[None])
# NOTE(review): `spacing` comes from earlier in the script; presumably it matches
# the forward map's grid -- confirm. The positional `1` is presumably the spline order.
warped_mermaid = compute_warped_image_multiNC(moving,
                                              mermaid_phi,
                                              spacing,
                                              1,
                                              zero_boundary=True)

# Recompute the spacing from the inverse map's grid, skipping its first axis
# (presumably the map-component axis -- confirm).
inv_phi_sz = np.array(mermaid_inv_phi.shape)
spacing = 1. / (np.array(inv_phi_sz[1:]) - 1)
target = torch.from_numpy(target_np[None][None])
mermaid_inv_phi = torch.from_numpy(mermaid_inv_phi[None])
inv_warped_mermaid = compute_warped_image_multiNC(target,
                                                  mermaid_inv_phi,
                                                  spacing,
                                                  1,
                                                  zero_boundary=True)

# Hard-coded output folder; the forward result uses the target image as
# reference, the inverse result uses the moving image.
output_path = "/playpen-raid1/zyshen/data/demo_for_lung_reg"
sitk.WriteImage(warped_itk, os.path.join(output_path, "warped_itk.nii.gz"))
ires.save_image_with_given_reference(warped_mermaid,
                                     reference_list=[target_path],
                                     path=output_path,
                                     fname=["warped_mermaid"])
sitk.WriteImage(inv_warped_itk,
                os.path.join(output_path, "inv_warped_itk.nii.gz"))
ires.save_image_with_given_reference(inv_warped_mermaid,
                                     reference_list=[moving_path],
                                     path=output_path,
                                     fname=["inv_warped_mermaid"])
Ejemplo n.º 4
0
    def _save_image_into_original_sz_with_given_reference(self, pair_path, phi, inverse_phi=None, use_01=False):
        """
        the images (moving, target, warped, transformation map, inverse transformation map world coord[0,1] ) are saved in record_path/original_sz

        Label images are only written when label references are available
        (``pair_path`` has four entries) and, for warped labels, when a warped
        label was actually produced. The forward displacement is written
        whenever ``save_disp`` is set, independently of ``inverse_phi``.

        :param pair_path: list: moving image path list, target image path list and,
            optionally, moving label path list, target label path list
        :param phi: transformation map BDXYZ
        :param inverse_phi: inverse transformation map BDXYZ
        :param use_01: indicate the transformation use [0,1] coord or [-1,1] coord
        :return: None
        """
        save_s, save_t, save_w, save_phi, save_w_inv, save_phi_inv, save_disp, _ = self.save_original_image_by_type
        spacing = self.spacing
        moving_reference_list = pair_path[0]
        target_reference_list = pair_path[1]
        moving_l_reference_list = None
        target_l_reference_list = None
        if len(pair_path) == 4:
            moving_l_reference_list = pair_path[2]
            target_l_reference_list = pair_path[3]
        # everything below works in [0,1] coordinates
        phi = (phi + 1) / 2. if not use_01 else phi
        new_phi, warped, warped_l, new_spacing = ires.resample_warped_phi_and_image(
            moving_reference_list, moving_l_reference_list, phi, spacing)
        saving_original_sz_path = os.path.join(self.record_path, 'original_sz')
        os.makedirs(saving_original_sz_path, exist_ok=True)
        if save_phi:
            ires.save_transfrom(new_phi, new_spacing, saving_original_sz_path, list(self.fname_list))
        if save_w:
            fname_list = [fname + '_warped' for fname in self.fname_list]
            ires.save_image_with_given_reference(warped, target_reference_list, saving_original_sz_path, fname_list)
            if warped_l is not None and target_l_reference_list is not None:
                fname_list = [fname + '_warped_l' for fname in self.fname_list]
                ires.save_image_with_given_reference(warped_l, target_l_reference_list, saving_original_sz_path, fname_list)
        if save_s:
            fname_list = [fname + '_moving' for fname in self.fname_list]
            ires.save_image_with_given_reference(None, moving_reference_list, saving_original_sz_path, fname_list)
            if moving_l_reference_list is not None:
                fname_list = [fname + '_moving_l' for fname in self.fname_list]
                ires.save_image_with_given_reference(None, moving_l_reference_list, saving_original_sz_path, fname_list)
        if save_t:
            fname_list = [fname + '_target' for fname in self.fname_list]
            ires.save_image_with_given_reference(None, target_reference_list, saving_original_sz_path, fname_list)
            if target_l_reference_list is not None:
                fname_list = [fname + '_target_l' for fname in self.fname_list]
                ires.save_image_with_given_reference(None, target_l_reference_list, saving_original_sz_path, fname_list)
        if save_disp:
            # forward displacement = map - identity. The identity comes from gen_identity_map
            # in [-1,1] and is shifted to [0,1]; its grid follows the warped image (same grid as new_phi).
            fname_list = [fname + '_disp' for fname in self.fname_list]
            id_map = gen_identity_map(warped.shape[2:], resize_factor=1., normalized=True).cuda()
            id_map = (id_map[None] + 1) / 2.
            disp = new_phi - id_map
            ires.save_transform_with_reference(disp, new_spacing, moving_reference_list, target_reference_list,
                                               path=saving_original_sz_path, fname_list=fname_list,
                                               save_disp_into_itk_format=True)
        if inverse_phi is not None:
            inverse_phi = (inverse_phi + 1) / 2. if not use_01 else inverse_phi
            new_inv_phi, inv_warped, inv_warped_l, new_inv_spacing = ires.resample_warped_phi_and_image(
                target_reference_list, target_l_reference_list, inverse_phi, spacing)
            if save_phi_inv:
                fname_list = [fname + '_inv' for fname in self.fname_list]
                ires.save_transfrom(new_inv_phi, new_inv_spacing, saving_original_sz_path, fname_list)
            if save_w_inv:
                fname_list = [fname + '_inv_warped' for fname in self.fname_list]
                ires.save_image_with_given_reference(inv_warped, moving_reference_list, saving_original_sz_path, fname_list)
                if inv_warped_l is not None and moving_l_reference_list is not None:
                    fname_list = [fname + '_inv_warped_l' for fname in self.fname_list]
                    ires.save_image_with_given_reference(inv_warped_l, moving_l_reference_list, saving_original_sz_path,
                                                         fname_list)
            if save_disp:
                # the identity must live on the inverse map's grid, i.e. that of inv_warped
                fname_list = [fname + '_inv_disp' for fname in self.fname_list]
                id_map = gen_identity_map(inv_warped.shape[2:], resize_factor=1., normalized=True).cuda()
                id_map = (id_map[None] + 1) / 2.
                inv_disp = new_inv_phi - id_map
                ires.save_transform_with_reference(inv_disp, new_inv_spacing, target_reference_list, moving_reference_list,
                                                   path=saving_original_sz_path, fname_list=fname_list,
                                                   save_disp_into_itk_format=True)
Ejemplo n.º 5
0
    def generate_aug_data(self, path_list, fname_list, init_weight_path_list,
                          output_path):
        """
        Generate augmented images (and labels, when available) through the atlas.

        For every input, the image-to-atlas map is obtained by shooting its
        to-atlas momentum to t=1. It is then propagated further with a random
        convex combination of K atlas-to-image momentums, integrated up to a
        random time in ``self.t_range``.

        Here we use the low-level interface of mermaid to propagate efficiently at
        low resolution. This avoids saving phi and inverse phi, as well as the
        precision loss from unnecessary upsampling and downsampling, and so keeps
        the maps precise.

        NOTE(review): ``mermaid_setting_path`` is not defined in this method;
        presumably it is a module-level name -- confirm.

        :param path_list: input entries; entry i is passed to ``self.get_input`` and
            ``path_list[i][0]`` is the reference image used when saving
        :param fname_list: optional names forwarded to ``self.get_input`` (may be None)
        :param init_weight_path_list: must be None (initial weights are not supported yet)
        :param output_path: folder where augmented images, labels and maps are written
        :raises ValueError: if an input has no matching to-atlas momentum, or if
            random weight/time sampling (``self.rand_w_t``) is disabled
        """
        def create_mermaid_model(mermaid_json_pth,
                                 img_sz,
                                 compute_inverse=True):
            """
            Build a mermaid registration model for images of size ``img_sz`` (NC + spatial)
            working at half resolution.

            ``compute_inverse`` is currently ignored: the inverse map is always computed,
            because ``_do_mermaid_reg`` results are unpacked into (map, inverse map).
            """
            import mermaid.model_factory as py_mf
            spacing = 1. / (np.array(img_sz[2:]) - 1)
            params = pars.ParameterDict()
            params.load_JSON(
                mermaid_json_pth)  # ''../easyreg/cur_settings_svf.json')
            model_name = params['model']['registration_model']['type']
            params.print_settings_off()
            mermaid_low_res_factor = 0.5
            lowResSize = get_res_size_from_size(img_sz, mermaid_low_res_factor)
            lowResSpacing = get_res_spacing_from_spacing(
                spacing, img_sz, lowResSize)
            mf = py_mf.ModelFactory(img_sz, spacing, lowResSize, lowResSpacing)
            model, criterion = mf.create_registration_model(
                model_name, params['model'], compute_inverse_map=True)
            lowres_id = identity_map_multiN(lowResSize, lowResSpacing)
            lowResIdentityMap = torch.from_numpy(lowres_id).cuda()

            _id = identity_map_multiN(img_sz, spacing)
            identityMap = torch.from_numpy(_id).cuda()
            mermaid_unit_st = model.cuda()
            mermaid_unit_st.associate_parameters_with_module()
            return mermaid_unit_st, criterion, lowResIdentityMap, lowResSize, lowResSpacing, identityMap, spacing

        def _set_mermaid_param(mermaid_unit, m):
            """Overwrite the model's momentum parameter in place."""
            mermaid_unit.m.data = m

        def _do_mermaid_reg(mermaid_unit,
                            low_phi,
                            m,
                            low_s=None,
                            low_inv_phi=None):
            """Shoot ``m`` starting from ``low_phi`` (no gradients); returns the model output."""
            with torch.no_grad():
                _set_mermaid_param(mermaid_unit, m)
                low_phi = mermaid_unit(low_phi, low_s, phi_inv=low_inv_phi)
            return low_phi

        def get_momentum_name(momentum_path):
            """File name of a momentum with the '_0000_Momentum' suffix removed."""
            fname = get_file_name(momentum_path)
            fname = fname.replace("_0000_Momentum", '')
            return fname

        max_aug_num = self.max_aug_num
        rand_w_t = self.rand_w_t
        t_range = self.t_range
        t_span = t_range[1] - t_range[0]
        K = self.K

        num_pair = len(path_list)
        assert init_weight_path_list is None, "init weight has not supported yet"
        # load all momentums: atlas-to-image ones start with "atlas", image-to-atlas ones don't
        read_image = lambda x: sitk.GetArrayFromImage(sitk.ReadImage(x))
        atlas_to_momentum_path_list = list(
            filter(
                lambda x: "Momentum" in x and get_file_name(x).find("atlas") ==
                0, glob(os.path.join(self.atlas_to_folder, "*nii.gz"))))
        to_atlas_momentum_path_list = list(
            filter(
                lambda x: "Momentum" in x and get_file_name(x).find("atlas") !=
                0, glob(os.path.join(self.to_atlas_folder, "*nii.gz"))))
        atlas_to_momentum_list = [
            torch.Tensor(read_image(pth).transpose()[None]).cuda()
            for pth in atlas_to_momentum_path_list
        ]
        to_atlas_momentum_list = [
            torch.Tensor(read_image(pth).transpose()[None]).cuda()
            for pth in to_atlas_momentum_path_list
        ]
        num_momentum = len(atlas_to_momentum_list)
        moving_example = read_image(path_list[0][0])
        img_sz = list(moving_example.shape)
        mermaid_unit_st, criterion, lowResIdentityMap, lowResSize, lowResSpacing, identityMap, spacing = create_mermaid_model(
            mermaid_setting_path, [1, 1] + img_sz, self.compute_inverse)

        for i in range(num_pair):
            fname = fname_list[i] if fname_list is not None else None
            moving, l_moving, moving_name = self.get_input(
                path_list[i], fname, None)
            low_moving = get_resampled_image(moving,
                                             None,
                                             lowResSize,
                                             1,
                                             zero_boundary=True,
                                             identity_map=lowResIdentityMap)
            init_map = lowResIdentityMap.clone()
            init_inverse_map = lowResIdentityMap.clone()
            # the first to-atlas momentum whose path contains this image's name
            index = next((j for j, pth in enumerate(to_atlas_momentum_path_list)
                          if moving_name in pth), None)
            if index is None:
                raise ValueError(
                    "no to-atlas momentum found for {}".format(moving_name))
            # here we are only interested in the forward map, so the moving image doesn't affect it
            mermaid_unit_st.integrator.cparams['tTo'] = 1.0
            low_phi_to_atlas, low_inverse_phi_to_atlas = _do_mermaid_reg(
                mermaid_unit_st,
                init_map,
                to_atlas_momentum_list[index],
                low_moving,
                low_inv_phi=init_inverse_map)
            num_aug = max(round(max_aug_num / num_pair), 1) if rand_w_t else 1

            for _ in range(num_aug):
                if rand_w_t:
                    # keep the random draw order: time, K weights, then the momentum subset
                    t_aug_list = [random.random() * t_span + t_range[0]]
                    weight = np.array([random.random() for _ in range(K)])
                    weight_list = [weight / np.sum(weight)]
                    selected_index = random.sample(list(range(num_momentum)),
                                                   K)
                else:
                    raise ValueError(
                        "In atlas augmentation mode, the data interpolation is disabled"
                    )
                for t_aug in t_aug_list:
                    if t_aug == 0:
                        continue
                    for weight in weight_list:
                        # convex combination of the selected atlas-to-image momentums
                        momentum = torch.zeros_like(atlas_to_momentum_list[0])
                        aug_fname = moving_name + "_to_"
                        suffix = ""
                        for k in range(K):
                            momentum += weight[k] * atlas_to_momentum_list[
                                selected_index[k]]
                            aug_fname += get_momentum_name(
                                atlas_to_momentum_path_list[
                                    selected_index[k]]) + '_'
                            suffix += '{:.4f}_'.format(weight[k])

                        aug_fname = aug_fname + suffix + 't_{:.2f}'.format(t_aug)
                        aug_fname = aug_fname.replace('.', 'd')

                        mermaid_unit_st.integrator.cparams['tTo'] = t_aug
                        low_phi_atlas_to, low_inverse_phi_atlas_to = _do_mermaid_reg(
                            mermaid_unit_st,
                            low_phi_to_atlas.clone(),
                            momentum,
                            low_moving,
                            low_inv_phi=low_inverse_phi_to_atlas.clone())
                        forward_map = get_resampled_image(
                            low_phi_atlas_to,
                            lowResSpacing, [1, 3] + img_sz,
                            1,
                            zero_boundary=False,
                            identity_map=identityMap)
                        warped = compute_warped_image_multiNC(
                            moving,
                            forward_map,
                            spacing,
                            spline_order=1,
                            zero_boundary=True)
                        if l_moving is not None:
                            l_warped = compute_warped_image_multiNC(
                                l_moving,
                                forward_map,
                                spacing,
                                spline_order=0,
                                zero_boundary=True)
                            save_image_with_given_reference(
                                l_warped, [path_list[i][0]], output_path,
                                [aug_fname + '_label'])
                        save_image_with_given_reference(
                            warped, [path_list[i][0]], output_path,
                            [aug_fname + '_image'])
                        # only the inverse map is saved; resample it only when it is needed
                        if self.save_tf_map and self.compute_inverse:
                            inverse_map = get_resampled_image(
                                low_inverse_phi_atlas_to,
                                lowResSpacing, [1, 3] + img_sz,
                                1,
                                zero_boundary=False,
                                identity_map=identityMap)
                            save_deformation(inverse_map, output_path,
                                             [aug_fname + '_inv_phi'])
Ejemplo n.º 6
0
    def generate_single_res(self, moving, l_moving, momentum, init_weight,
                            initial_map, initial_inverse_map, fname, t_aug,
                            output_path, moving_path):
        """
        Warp ``moving`` (and ``l_moving`` if given) with the map obtained by
        integrating ``momentum`` up to time ``t_aug``, and save the results.

        If ``momentum`` is None, a random momentum is drawn uniformly in
        [-self.magnitude, self.magnitude] on a grid shrunk by
        ``self.rand_momentum_shrink_factor``, then resampled to half the image size.
        When ``self.affine_back_to_original_postion`` is set and an initial inverse
        map is given, the result is composed with that inverse map. The inverse
        transformation map is then not saved.

        :param moving: moving image tensor; presumably 1 x 1 x spatial -- confirm
        :param l_moving: label tensor of the moving image, or None
        :param momentum: momentum at half resolution (1 x 3 x spatial/2), or None to draw one
        :param init_weight: local weights forwarded to mermaid
        :param initial_map: optional initial map (full resolution)
        :param initial_inverse_map: optional initial inverse map (full resolution)
        :param fname: prefix of the saved files
        :param t_aug: end time of the integration (written into the forward_model 'tTo' setting)
        :param output_path: output folder
        :param moving_path: reference image used when saving
        """
        params = self.mermaid_setting
        params['model']['registration_model']['forward_model']['tTo'] = t_aug

        # here we assume the momentum is computed at low_resol_factor=0.5
        if momentum is not None:
            input_img_sz = [1, 1] + [int(sz * 2) for sz in momentum.shape[2:]]
        else:
            input_img_sz = list(moving.shape)
            momentum_sz_low = [1, 3] + [
                int(dim / self.rand_momentum_shrink_factor)
                for dim in input_img_sz[2:]
            ]
            momentum_sz = [1, 3] + [int(dim / 2) for dim in input_img_sz[2:]]
            momentum = (np.random.rand(*momentum_sz_low) * 2 -
                        1) * self.magnitude
            mom_spacing = 1. / (np.array(momentum_sz_low[2:]) - 1)
            momentum = torch.Tensor(momentum).cuda()
            momentum, _ = resample_image(momentum,
                                         mom_spacing,
                                         momentum_sz,
                                         spline_order=1,
                                         zero_boundary=True)
        if self.resize_output != [-1, -1, -1]:
            momentum_sz = [1, 3] + [int(dim / 2) for dim in self.resize_output]
            mom_spacing = 1. / (np.array(momentum_sz[2:]) - 1)
            momentum, _ = resample_image(momentum,
                                         mom_spacing,
                                         momentum_sz,
                                         spline_order=1,
                                         zero_boundary=True)
            input_img_sz = [1, 1] + [int(sz * 2) for sz in momentum.shape[2:]]

        org_spacing = 1.0 / (np.array(moving.shape[2:]) - 1)
        input_spacing = 1.0 / (np.array(input_img_sz[2:]) - 1)
        size_diff = input_img_sz != list(moving.shape)
        if size_diff:
            input_img, _ = resample_image(moving, org_spacing, input_img_sz)
        else:
            input_img = moving
        low_initial_map = None
        low_init_inverse_map = None
        if initial_map is not None:
            low_initial_map, _ = resample_image(
                initial_map, input_spacing, [1, 3] + list(momentum.shape[2:]))
        if initial_inverse_map is not None:
            low_init_inverse_map, _ = resample_image(initial_inverse_map,
                                                     input_spacing, [1, 3] +
                                                     list(momentum.shape[2:]))
        individual_parameters = dict(m=momentum, local_weights=init_weight)
        sz = np.array(input_img.shape)
        res = evaluate_model(input_img,
                             input_img,
                             sz,
                             input_spacing,
                             use_map=True,
                             compute_inverse_map=self.compute_inverse,
                             map_low_res_factor=0.5,
                             compute_similarity_measure_at_low_res=False,
                             spline_order=1,
                             individual_parameters=individual_parameters,
                             shared_parameters=None,
                             params=params,
                             extra_info=None,
                             visualize=False,
                             visual_param=None,
                             given_weight=False,
                             init_map=initial_map,
                             lowres_init_map=low_initial_map,
                             init_inverse_map=initial_inverse_map,
                             lowres_init_inverse_map=low_init_inverse_map)
        phi_new = res[1]
        if size_diff:
            phi_new, _ = resample_image(res[1], input_spacing,
                                        [1, 3] + list(moving.shape[2:]))
        warped = compute_warped_image_multiNC(moving,
                                              phi_new,
                                              org_spacing,
                                              spline_order=1,
                                              zero_boundary=True)
        # the map is composed with the initial inverse map only in this case
        affine_applied = initial_inverse_map is not None and self.affine_back_to_original_postion
        if affine_applied:
            # here we take zero boundary boundary which need two step image interpolation
            warped = compute_warped_image_multiNC(warped,
                                                  initial_inverse_map,
                                                  org_spacing,
                                                  spline_order=1,
                                                  zero_boundary=True)
            phi_new = compute_warped_image_multiNC(phi_new,
                                                   initial_inverse_map,
                                                   org_spacing,
                                                   spline_order=1)
        save_image_with_given_reference(warped, [moving_path], output_path,
                                        [fname + '_image'])
        if l_moving is not None:
            # we assume the label doesnt lie at the boundary
            l_warped = compute_warped_image_multiNC(l_moving,
                                                    phi_new,
                                                    org_spacing,
                                                    spline_order=0,
                                                    zero_boundary=True)
            save_image_with_given_reference(l_warped, [moving_path],
                                            output_path, [fname + '_label'])

        if self.save_tf_map:
            save_deformation(phi_new, output_path, [fname + '_phi_map'])
            if self.compute_inverse:
                # the inverse from mermaid no longer matches once the affine part was composed in
                if affine_applied:
                    print(
                        "Cannot compute the inverse map when affine back to the source image position"
                    )
                    return
                inv_phi_new = res[2]
                if size_diff:
                    inv_phi_new, _ = resample_image(res[2], input_spacing,
                                                    [1, 3] +
                                                    list(moving.shape[2:]))
                save_deformation(inv_phi_new, output_path,
                                 [fname + '_inv_map'])
Ejemplo n.º 7
0
    def _save_image_into_original_sz_with_given_reference(
            self, pair_path, phis, inverse_phis=None, use_01=False):
        """
        the images (moving, target, warped, transformation map, inverse transformation map world coord[0,1] ) are saved in record_path/original_sz

        Each pair in the batch is processed separately; intermediate results are
        deleted as soon as they are saved to keep memory usage low. Warped labels are
        only saved when both the warped label and its reference label exist.

        :param pair_path: list: moving image path list, target image path list and,
            optionally, moving label path list, target label path list
        :param phis: transformation map BDXYZ
        :param inverse_phis: inverse transformation map BDXYZ
        :param use_01: indicate the transformation use [0,1] coord or [-1,1] coord
        :return: None
        """
        save_original_resol_by_type = self.save_original_resol_by_type
        save_s, save_t, save_w, save_phi, save_w_inv, save_phi_inv, save_disp, save_extra_not_used_here = save_original_resol_by_type
        spacing = self.spacing
        moving_reference_list = pair_path[0]
        target_reference_list = pair_path[1]
        moving_l_reference_list = None
        target_l_reference_list = None
        if len(pair_path) == 4:
            moving_l_reference_list = pair_path[2]
            target_l_reference_list = pair_path[3]
        # everything below works in [0,1] coordinates
        phis = (phis + 1) / 2. if not use_01 else phis
        saving_original_sz_path = os.path.join(self.record_path, 'original_sz')
        os.makedirs(saving_original_sz_path, exist_ok=True)
        for i in range(len(moving_reference_list)):
            moving_reference = moving_reference_list[i]
            target_reference = target_reference_list[i]
            moving_l_reference = moving_l_reference_list[
                i] if moving_l_reference_list else None
            target_l_reference = target_l_reference_list[
                i] if target_l_reference_list else None
            fname = self.fname_list[i]
            phi = phis[i:i + 1]
            inverse_phi = inverse_phis[i:i +
                                       1] if inverse_phis is not None else None

            new_phi, warped, warped_l, new_spacing = ires.resample_warped_phi_and_image(
                moving_reference, target_reference, moving_l_reference,
                target_l_reference, phi, spacing)
            if save_phi or save_disp:
                if save_phi:
                    ires.save_transfrom(new_phi, new_spacing,
                                        saving_original_sz_path, [fname])
                if save_disp:
                    # displacement = map - identity. The identity comes from gen_identity_map
                    # in [-1,1] and is shifted to [0,1]; its grid follows the warped image.
                    cur_fname = fname + '_disp'
                    id_map = gen_identity_map(warped.shape[2:],
                                              resize_factor=1.,
                                              normalized=True).cuda()
                    id_map = (id_map[None] + 1) / 2.
                    disp = new_phi - id_map
                    ires.save_transform_with_reference(
                        disp,
                        new_spacing, [moving_reference], [target_reference],
                        path=saving_original_sz_path,
                        fname_list=[cur_fname],
                        save_disp_into_itk_format=True)
                    del id_map, disp
            del new_phi, phi
            if save_w:
                cur_fname = fname + '_warped'
                ires.save_image_with_given_reference(warped,
                                                     [target_reference],
                                                     saving_original_sz_path,
                                                     [cur_fname])
                # both the warped label and its reference are needed to write it
                if warped_l is not None and target_l_reference is not None:
                    cur_fname = fname + '_warped_l'
                    ires.save_image_with_given_reference(
                        warped_l, [target_l_reference],
                        saving_original_sz_path, [cur_fname])
            del warped
            if save_s:
                cur_fname = fname + '_moving'
                ires.save_image_with_given_reference(None, [moving_reference],
                                                     saving_original_sz_path,
                                                     [cur_fname])
                if moving_l_reference is not None:
                    cur_fname = fname + '_moving_l'
                    ires.save_image_with_given_reference(
                        None, [moving_l_reference], saving_original_sz_path,
                        [cur_fname])
            if save_t:
                cur_fname = fname + '_target'
                ires.save_image_with_given_reference(None, [target_reference],
                                                     saving_original_sz_path,
                                                     [cur_fname])
                if target_l_reference is not None:
                    cur_fname = fname + '_target_l'
                    ires.save_image_with_given_reference(
                        None, [target_l_reference], saving_original_sz_path,
                        [cur_fname])
            if inverse_phi is not None:
                inverse_phi = (inverse_phi +
                               1) / 2. if not use_01 else inverse_phi
                new_inv_phi, inv_warped, inv_warped_l, new_spacing = ires.resample_warped_phi_and_image(
                    target_reference, moving_reference, target_l_reference,
                    moving_l_reference, inverse_phi, spacing)
                if save_phi_inv:
                    cur_fname = fname + '_inv'
                    ires.save_transfrom(new_inv_phi, new_spacing,
                                        saving_original_sz_path, [cur_fname])
                if save_w_inv:
                    cur_fname = fname + '_inv_warped'
                    ires.save_image_with_given_reference(
                        inv_warped, [moving_reference],
                        saving_original_sz_path, [cur_fname])
                    # passing None as the image would just copy the reference label
                    if inv_warped_l is not None and moving_l_reference is not None:
                        cur_fname = fname + '_inv_warped_l'
                        ires.save_image_with_given_reference(
                            inv_warped_l, [moving_l_reference],
                            saving_original_sz_path, [cur_fname])
                if save_disp:
                    cur_fname = fname + '_inv_disp'
                    id_map = gen_identity_map(inv_warped.shape[2:],
                                              resize_factor=1.,
                                              normalized=True).cuda()
                    id_map = (id_map[None] + 1) / 2.
                    inv_disp = new_inv_phi - id_map
                    ires.save_transform_with_reference(
                        inv_disp,
                        new_spacing, [target_reference], [moving_reference],
                        path=saving_original_sz_path,
                        fname_list=[cur_fname],
                        save_disp_into_itk_format=True)
                    del id_map, inv_disp
                del new_inv_phi, inv_warped, inverse_phi