예제 #1
0
파일: RegNet3D.py 프로젝트: zlinzju/RegNet
def backup_script(setting):
    """
    Create a new experiment name, back up the code that runs it and start logging.

    Steps performed:
    1. Build ``setting['current_experiment']`` from the current timestamp
       (YYYYMMDD_HHMMSS), the deform_exp of the first entry in
       ``setting['DataExpDict']``, ``setting['AGMode']``, the stage and a short
       network name derived from ``setting['NetworkDesign']``.
    2. Copy this script and the sibling ``functions`` folder into the model folder.
    3. Point the logger at the experiment's log file.

    :param setting: experiment settings dict; modified in place.
    :return: the same ``setting`` dict with ``'current_experiment'`` filled in.
    """
    design = setting['NetworkDesign']
    # Priority order: 'unet' wins over 'decimation', which wins over 'crop'.
    if 'unet' in design:
        network_name = 'unet' + design[-1]
    elif 'decimation' in design:
        network_name = 'dec' + design[-1]
    elif 'crop' in design:
        network_name = design.rsplit('_', 1)[0]
    else:
        network_name = ''

    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    name_parts = [
        timestamp,
        setting['DataExpDict'][0]['deform_exp'],
        setting['AGMode'],
        'S' + str(setting['stage']),
        network_name,
    ]
    setting['current_experiment'] = '_'.join(name_parts)

    model_folder = su.address_generator(setting, 'ModelFolder')
    if not os.path.isdir(model_folder):
        os.makedirs(model_folder)

    script_path = Path(__file__)
    shutil.copy(script_path, model_folder)
    shutil.copytree(script_path.parent / 'functions', Path(model_folder) / 'functions')
    gut.logger.set_log_file(su.address_generator(setting, 'LogFile'))
    return setting
예제 #2
0
def lung_fill_hole_erode(setting, cn=None):
    """
    Write a morphologically processed copy of the filled lung mask for every image type.

    For each ``type_im`` in ``range(len(setting['types']))`` the 'Lung_Filled' mask is
    read and binarized (``> 0``). It is then processed with a 3x3x3 all-ones structuring
    element and written as int8 to the 'Lung_Filled_Erode' address, keeping the geometry
    of the input image.

    NOTE(review): despite the 'erode' naming, the operation applied is a binary
    *dilation*. This is kept unchanged; confirm whether erosion was intended.

    :param setting: experiment settings dict passed to ``su.address_generator``.
    :param cn: case number forwarded to ``su.address_generator``.
    """
    folder = su.address_generator(setting,
                                  'Lung_Filled_Erode',
                                  cn=cn,
                                  type_im=0).rsplit('/', maxsplit=1)[0]
    if not os.path.isdir(folder):
        os.makedirs(folder)
    # np.bool was removed in NumPy 1.24; the builtin bool is the same dtype.
    # The structuring element does not depend on the image, so build it once.
    structure = np.ones([3, 3, 3], dtype=bool)
    for type_im in range(len(setting['types'])):
        lung_raw_filled_sitk = sitk.ReadImage(
            su.address_generator(setting,
                                 'Lung_Filled',
                                 cn=cn,
                                 type_im=type_im))
        lung_raw_filled = sitk.GetArrayFromImage(lung_raw_filled_sitk) > 0
        # scipy.ndimage.morphology is a deprecated alias namespace; use ndimage directly.
        lung_filled_erode = ndimage.binary_dilation(
            lung_raw_filled, structure=structure).astype(np.int8)
        sitk.WriteImage(
            ip.array_to_sitk(lung_filled_erode, im_ref=lung_raw_filled_sitk),
            su.address_generator(setting,
                                 'Lung_Filled_Erode',
                                 cn=cn,
                                 type_im=type_im))
예제 #3
0
def cylinder_mask(setting, cn=None, overwrite=False):
    """
    Write an eroded "cylinder" mask for every image type of a case.

    Voxels whose intensity lies in ``[DefaultPixelValue - 1, DefaultPixelValue + 0.01]``
    get 0 and all other voxels get 1. The result is eroded twice in-plane with a
    (1, 3, 3) structuring element and written as int8 with the image's geometry.
    Existing masks are kept unless ``overwrite`` is True.

    :param setting: experiment settings dict.
    :param cn: case number forwarded to ``su.address_generator``.
    :param overwrite: regenerate masks that already exist on disk.
    """
    def _address(name, type_im):
        return su.address_generator(setting, name, cn=cn, type_im=type_im)

    target_dir = _address('Cylinder', 0).rsplit('/', maxsplit=1)[0]
    if not os.path.isdir(target_dir):
        os.makedirs(target_dir)

    for type_im in range(len(setting['types'])):
        out_address = _address('Cylinder', type_im)
        if os.path.isfile(out_address) and not overwrite:
            continue

        image_sitk = sitk.ReadImage(_address('Im', type_im))
        default_value = setting['DefaultPixelValue']
        raw_mask_sitk = sitk.BinaryThreshold(image_sitk,
                                             lowerThreshold=default_value - 1,
                                             upperThreshold=default_value + 0.01,
                                             insideValue=0,
                                             outsideValue=1)
        # Per the original author, erosion with ndimage is ~5x faster than with SimpleITK.
        eroded = ndimage.binary_erosion(sitk.GetArrayFromImage(raw_mask_sitk),
                                        structure=np.ones((1, 3, 3)),
                                        iterations=2).astype(np.int8)
        sitk.WriteImage(ip.array_to_sitk(eroded, im_ref=image_sitk), out_address)
        logging.debug(out_address + ' is done')
예제 #4
0
파일: transform.py 프로젝트: zlinzju/RegNet
def bsplin_transformix_dvf(setting, pair_info, stage=1, overwrite=False):
    """
    Run transformix on the B-spline registration output to produce its DVF.

    The run is skipped (returning 0) when the 'DVFBSpline' file already exists and
    ``overwrite`` is False. Otherwise transformix is called with ``points='all'`` on the
    'BSplineOutputParameter' file, writing into the 'BSplineFolder'.

    :param setting: experiment settings dict.
    :param pair_info: [fixed_info, moving_info]; ``pair_info[0]`` must contain
                      'data', 'cn' and 'type_im'.
    :param stage: stage written into the fixed-image info before building addresses.
    :param overwrite: re-run even when the output already exists.
    :return: 0 when skipped, otherwise None.
    """
    fixed_info = copy.deepcopy(pair_info[0])
    fixed_info['stage'] = stage

    def _address(name):
        return su.address_generator(setting, name, pair_info=pair_info, **fixed_info)

    dvf_exists = os.path.isfile(_address('DVFBSpline'))
    case_text = pair_info[0]['data'] + ' CN = {} TypeIm = {}'.format(
        pair_info[0]['cn'], pair_info[0]['type_im'])

    if dvf_exists and not overwrite:
        logging.debug('BSpline transformix skipping... Data=' + case_text)
        return 0
    action = 'overwriting' if dvf_exists else 'starting'
    logging.debug('BSpline transformix ' + action + '... Data=' + case_text)

    output_folder = _address('BSplineFolder')
    parameter_file = _address('BSplineOutputParameter')
    elxpy.transformix(parameter_file=parameter_file,
                      input_image=None,
                      output_directory=output_folder,
                      points='all',
                      threads=setting['Reg_NumberOfThreads'])
예제 #5
0
def calculate_jacobian(setting, pair_info, overwrite=False):
    """
    Compute the Jacobian determinant image of a pair's DVF and write it to disk.

    For the 'elx_registration' experiment the elastix B-spline DVF is used
    ('DVFBSpline' -> 'DVFBSpline_Jac'). Otherwise the stage-0 network DVF is used
    ('dvf_s0' -> 'dvf_s0_jac'). An existing Jacobian file is kept unless ``overwrite``
    is True.

    :param setting: experiment settings dict; reads 'ImagePyramidSchedule' and
                    'current_experiment'.
    :param pair_info: [fixed_info, moving_info]; ``pair_info[0]`` must contain
                      'data', 'cn' and 'type_im' (used for logging).
    :param overwrite: recompute even if the Jacobian file exists.
    :return: 0
    """
    stage_list = setting['ImagePyramidSchedule']
    if setting['current_experiment'] == 'elx_registration':
        dvf_name, jac_name = 'DVFBSpline', 'DVFBSpline_Jac'
    else:
        dvf_name, jac_name = 'dvf_s0', 'dvf_s0_jac'

    def _address(name):
        return su.address_generator(setting, name, pair_info=pair_info, stage_list=stage_list)

    fixed = pair_info[0]
    jac_address = _address(jac_name)
    if not overwrite and os.path.isfile(jac_address):
        logging.debug(
            fixed['data'] +
            ', CN{}, ImType{} Jacobian is already available. skipping... '.format(
                fixed['cn'], fixed['type_im']))
        return 0

    start = time.time()
    dvf_sitk = sitk.ReadImage(_address(dvf_name))
    dvf = sitk.GetArrayFromImage(dvf_sitk)
    # SimpleITK spacing is ordered (x, y, z) while the numpy array is (z, y, x).
    jac = ip.calculate_jac(dvf, dvf_sitk.GetSpacing()[::-1])
    sitk.WriteImage(sitk.GetImageFromArray(jac.astype(np.float32)), jac_address)
    logging.debug(fixed['data'] +
                  ', CN{}, ImType{} Jacobian is done in {:.2f}s '.format(
                      fixed['cn'], fixed['type_im'], time.time() - start))
    # TODO: writing a histogram of the Jacobian is not implemented here.
    return 0
예제 #6
0
파일: transform.py 프로젝트: zlinzju/RegNet
def base_reg_transformix_points(setting,
                                pair_info,
                                stage=1,
                                overwrite=False,
                                base_reg=None):
    """
    Transform the fixed-image landmark points (index or world) with the base registration.

    This uses transformix with ``<BaseRegFolder>TransformParameters.0.txt`` and the
    fixed image's 'LandmarkPoint_elx' file. It is also possible to read the affine
    parameters and do the math directly; transformix is used here instead.

    :param setting: experiment settings dict.
    :param pair_info: [fixed_info, moving_info]; ``pair_info[0]`` must contain
                      'data', 'cn' and 'type_im'.
    :param stage: stage written into the fixed-image info before building addresses.
    :param overwrite: re-run transformix even if the output points file exists.
    :param base_reg: base-registration name; forwarded to ``su.address_generator`` and
                     used as the log-message prefix.
    :return: 0 when the output exists and ``overwrite`` is False, otherwise None.
    """
    im_info_fixed = copy.deepcopy(pair_info[0])
    im_info_fixed['stage'] = stage

    def _log(action):
        # str() so the default base_reg=None no longer raises TypeError on concatenation.
        logging.debug(str(base_reg) + ' transformix ' + action + '... Data=' +
                      pair_info[0]['data'] + ' CN = {} TypeIm = {}'.format(
                          pair_info[0]['cn'], pair_info[0]['type_im']))

    base_reg_output_points = su.address_generator(setting,
                                                  'Reg_BaseReg_OutputPoints',
                                                  pair_info=pair_info,
                                                  base_reg=base_reg,
                                                  **im_info_fixed)
    if os.path.isfile(base_reg_output_points):
        if not overwrite:
            _log('skipping')
            return 0
        _log('overwriting')
    else:
        _log('starting')

    fixed_landmarks_point_elx_address = su.address_generator(
        setting, 'LandmarkPoint_elx', pair_info=pair_info, **im_info_fixed)
    base_reg_folder = su.address_generator(setting,
                                           'BaseRegFolder',
                                           pair_info=pair_info,
                                           base_reg=base_reg,
                                           **im_info_fixed)
    elxpy.transformix(parameter_file=base_reg_folder +
                      'TransformParameters.0.txt',
                      output_directory=base_reg_folder,
                      points=fixed_landmarks_point_elx_address,
                      transformix_address='transformix',
                      threads=setting['Reg_NumberOfThreads'])
예제 #7
0
파일: utils.py 프로젝트: zlinzju/RegNet
def remove_redundant_images(setting, im_info, stage=1):
    """
    Delete intermediate synthetic images of one image at the given stage.

    Removed (when present on disk): DeformedDVF, Jac, one DeformedIm per cumulative
    prefix of ``im_info['deformed_im_ext']``, DeformedTorso, NextIm, NextTorso and
    NextLung.

    :param setting: experiment settings dict.
    :param im_info: dict with 'data', 'deform_exp', 'type_im', 'cn', 'dsmooth' and
                    'deformed_im_ext'.
    :param stage: stage whose files are removed (default 1).
    """
    im_info_su = {key: im_info[key] for key in ('data', 'deform_exp', 'type_im', 'cn', 'dsmooth')}
    im_info_su['stage'] = stage

    def _address(name, **extra):
        return su.address_generator(setting, name, **extra, **im_info_su)

    targets = [_address('DeformedDVF'), _address('Jac')]

    # Each DeformedIm name encodes the chain of extensions applied so far.
    applied_ext = []
    for ext in im_info['deformed_im_ext']:
        applied_ext.append(ext)
        targets.append(_address('DeformedIm', deformed_im_ext=applied_ext))

    targets += [_address(name) for name in ('DeformedTorso', 'NextIm', 'NextTorso', 'NextLung')]

    for path in targets:
        if os.path.isfile(path):
            os.remove(path)
예제 #8
0
def load_landmarks(setting, pair_info_list, experiment_list_new):
    """
    Load the landmarks file of each experiment and pick out the entry for every pair.

    Each experiment has its own landmarks file. For every pair in ``pair_info_list`` a
    dictionary with two keys is created:
        'pair_info':     a copy of pair_info
        'landmark_info': a copy of the matching entry from the landmarks file, or an
                         empty dict when no entry matches on 'data', 'cn' and 'type_im'.

    :param setting: experiment settings dict.
    :param pair_info_list: list of [fixed_info, moving_info] pairs.
    :param experiment_list_new: list of dicts with keys 'experiment', 'stage_list',
                                'BaseReg' and 'GlobalStepLoad'.
    :return: landmarks_dict: a dictionary keyed by '<experiment>_<BaseReg>':
             landmarks_dict[key][index of pair]['pair_info' | 'landmark_info']
    """
    landmarks_dict = dict()
    for exp_dict in experiment_list_new:
        # The former unused `exp_pure = experiment.rsplit('/')[1]` raised IndexError for
        # experiment names without '/', so it was removed.
        stage_list = exp_dict['stage_list']
        exp = exp_dict['experiment']

        exp_folder = exp.split('-')[0]
        exp_key_name = exp + '_' + exp_dict['BaseReg']
        landmark_address = su.address_generator(
            setting,
            'landmarks_file',
            current_experiment=exp_folder,
            stage_list=stage_list,
            base_reg=exp_dict['BaseReg'],
            step=exp_dict['GlobalStepLoad'])
        # NOTE: dill (like pickle) can execute arbitrary code while loading;
        # only load landmark files produced by trusted runs.
        with open(landmark_address, 'rb') as f:
            landmarks_load = dill.load(f)
        landmarks_dict[exp_key_name] = []
        for pair_info in pair_info_list:
            pair_info_text = exp_key_name + ' Fixed: ' + pair_info[0]['data'] + \
                '_CN{}_TypeIm{},'.format(pair_info[0]['cn'], pair_info[0]['type_im']) + '  Moving:' + \
                pair_info[1]['data'] + '_CN{}_TypeIm{}'.format(pair_info[1]['cn'], pair_info[1]['type_im'])
            print('loading ' + pair_info_text)
            # Index of the first matching entry, found in a single pass.
            ind_find = next(
                (i for i, landmark_i in enumerate(landmarks_load)
                 if compare_pair_info_dict(pair_info,
                                           landmark_i['pair_info'],
                                           compare_keys=['data', 'cn', 'type_im'])),
                None)
            if ind_find is None:
                print('landmark not found in experiment:' + pair_info_text)
                landmark_pair = dict()
            else:
                landmark_pair = copy.deepcopy(landmarks_load[ind_find])

            landmarks_dict[exp_key_name].append({
                'pair_info': copy.deepcopy(pair_info),
                'landmark_info': landmark_pair
            })
    return landmarks_dict
예제 #9
0
def background_to_zero_linear(setting,
                              im_info_su,
                              gonna_generate_next_im=False):
    """
    Build a weight map that is 1 inside the torso and decays as 1/distance outside it.

    Background voxels (torso == 0) get ``1 / d``, where ``d`` is the Euclidean distance
    to the nearest torso voxel, measured with ``setting['VoxelSize']`` as sampling.
    Weights below 0.05 are set to 0.

    :param setting: experiment settings dict; reads 'VoxelSize'.
    :param im_info_su: image-info dict passed to ``su.address_generator``.
    :param gonna_generate_next_im: if True, use the torso of the ``dsmooth = 0`` image
                                   instead of the one in ``im_info_su``.
    :return: float64 numpy array with the weights, in the torso image's array layout.
    """
    if gonna_generate_next_im:
        im_info_torso = copy.deepcopy(im_info_su)
        im_info_torso['dsmooth'] = 0
    else:
        im_info_torso = im_info_su
    torso_address = su.address_generator(setting, 'Torso', **im_info_torso)

    torso_im = sitk.GetArrayFromImage(sitk.ReadImage(torso_address))
    # Assumes the torso mask is binary 0/1 (TODO confirm). Then 1 - torso is non-zero
    # exactly on the background, and the EDT gives each background voxel's distance
    # to the torso.
    torso_distance = ndimage.distance_transform_edt(
        1 - torso_im, sampling=setting['VoxelSize'])
    # np.float was removed in NumPy 1.24; astype already returns a copy.
    mask_to_zero = torso_im.astype(float)
    # Must be a plain boolean array: indexing with a list-wrapped mask ([mask]) is
    # interpreted differently (and fails) since NumPy 1.23.
    background_ind = torso_im == 0
    mask_to_zero[background_ind] = 1 / torso_distance[background_ind]
    mask_to_zero[mask_to_zero < 0.05] = 0
    return mask_to_zero
예제 #10
0
def resampling(data, spacing=None, requested_im_list=None):
    """
    Resample the raw images of a dataset to a new voxel spacing and write them to disk.

    For every image type and case number of ``data``, each requested image is read from
    'Original<name>Raw', resampled with ``ip.resampler_sitk`` and written to
    'Original<name>'.

    :param data: dataset name (key into ``setting['data']``).
    :param spacing: target spacing; defaults to [1, 1, 1].
    :param requested_im_list: image names among 'Im', 'Lung', 'Torso'; defaults to ['Im'].
    :raises ValueError: if a requested name has no interpolator. All names are checked
                        before any image is read or written.
    """
    if spacing is None:
        spacing = [1, 1, 1]
    if requested_im_list is None:
        requested_im_list = ['Im']

    # Intensity images use B-spline. Lung/Torso label images use nearest neighbour,
    # so their label values are preserved.
    interpolators = {
        'Im': sitk.sitkBSpline,
        'Lung': sitk.sitkNearestNeighbor,
        'Torso': sitk.sitkNearestNeighbor,
    }
    for requested_im in requested_im_list:
        if requested_im not in interpolators:
            # The message previously advertised "Mask", which was never accepted.
            raise ValueError(
                'interpolator is only defined for ["Im", "Lung", "Torso"] not for '
                + requested_im)

    data_exp_dict = [{'data': data, 'deform_exp': '3D_max25_D12'}]
    setting = su.initialize_setting('')
    setting = su.load_setting_from_data_dict(setting, data_exp_dict)
    default_pixel_value = setting['data'][data]['DefaultPixelValue']

    for type_im in range(len(setting['data'][data]['types'])):
        for cn in setting['data'][data]['CNList']:
            im_info_su = {
                'data': data,
                'type_im': type_im,
                'cn': cn,
                'stage': 1
            }
            for requested_im in requested_im_list:
                im_raw_sitk = sitk.ReadImage(
                    su.address_generator(setting,
                                         'Original' + requested_im + 'Raw',
                                         **im_info_su))
                im_resampled_sitk = ip.resampler_sitk(
                    im_raw_sitk,
                    spacing=spacing,
                    default_pixel_value=default_pixel_value,
                    interpolator=interpolators[requested_im],
                    dimension=3)
                sitk.WriteImage(
                    im_resampled_sitk,
                    su.address_generator(setting, 'Original' + requested_im,
                                         **im_info_su))
                print(data + '_TypeIm' + str(type_im) + '_CN' + str(cn) + '_' +
                      requested_im + ' resampled to ' + str(spacing) + ' mm')
예제 #11
0
def dvf_statistics(setting, dvf, spacing=None, im_info=None, stage=None):
    """
    Save a histogram of a DVF, write its Jacobian image and save a Jacobian histogram.

    :param setting: experiment settings dict; reads
                    ``setting['deform_exp'][deform_exp]['MaxDeform']``.
    :param dvf: displacement field as a numpy array.
    :param spacing: voxel spacing passed to ``ip.calculate_jac``.
    :param im_info: dict with 'data', 'deform_exp', 'type_im', 'cn', 'dsmooth' and
                    'padto'. It is required despite the None default.
    :param stage: stage used when building output addresses.
    """
    im_info_su = {key: im_info[key] for key in ('data', 'deform_exp', 'type_im', 'cn', 'dsmooth')}
    im_info_su['stage'] = stage
    im_info_su['padto'] = im_info['padto']

    max_dvf = np.max(setting['deform_exp'][im_info['deform_exp']]['MaxDeform'])
    import matplotlib.pyplot as plt

    # Unit-width bins from -max_dvf to max_dvf. NOTE(review): values outside this range
    # are not counted; the original author states the DVF stays within it.
    dvf_bins = np.arange(-max_dvf, max_dvf + 1)
    plt.figure()
    plt.hist(np.ravel(dvf), log=True, bins=dvf_bins)
    plt.draw()
    plt.savefig(su.address_generator(setting, 'DVF_histogram', **im_info_su))
    plt.close()

    jac = ip.calculate_jac(dvf, spacing)
    sitk.WriteImage(sitk.GetImageFromArray(jac.astype(np.float32)),
                    su.address_generator(setting, 'Jac', **im_info_su))

    jac_min = np.min(jac)
    jac_max = np.max(jac)
    step_h = 0.2
    # Histogram range is at least [-1, 3], widened to whole numbers if exceeded.
    hist_max = np.ceil(jac_max) if jac_max > 3 else 3
    hist_min = np.floor(jac_min) if jac_min < -1 else -1

    plt.figure()
    plt.hist(np.ravel(jac),
             log=True,
             bins=np.arange(hist_min, hist_max + step_h, step_h))
    plt.title('min(Jac)={:.2f}, max(Jac)={:.2f}'.format(jac_min, jac_max))
    plt.draw()
    plt.savefig(su.address_generator(setting, 'Jac_histogram', **im_info_su))
    plt.close()
예제 #12
0
def do_mask_to_zero_gaussian(setting, im_info_su, dvf, mask_to_zero, stage,
                             max_deform, sigma):
    """
    Zero the DVF outside a mask, then Gaussian-smooth it.

    :param setting: experiment settings dict; reads 'ParallelSearching'.
    :param im_info_su: image-info dict used to locate the mask image.
    :param dvf: 4-D displacement array whose last axis holds the vector components
                (the mask is broadcast along axis 3).
    :param mask_to_zero: address-generator name of the mask image.
    :param stage: pyramid stage; sigma is divided by it.
    :param max_deform: maximum deformation; sigma is scaled by ``max_deform / 7``.
    :param sigma: base smoothing sigma, replicated to all 3 axes.
    :return: the masked and smoothed DVF.
    """
    mask_im = sitk.GetArrayFromImage(
        sitk.ReadImage(su.address_generator(setting, mask_to_zero, **im_info_su)))
    # Broadcast the scalar mask over every vector component.
    masked_dvf = dvf * mask_im[..., np.newaxis]
    # The original note: stage 4 needs a smaller sigma, but its max_deform of 20 led to
    # negative Jacobians. Other sigma values in the code cause no problem.
    scaled_sigma = np.tile(sigma / stage * max_deform / 7, 3)
    return smooth_dvf(masked_dvf,
                      sigma_blur=scaled_sigma,
                      parallel_processing=setting['ParallelSearching'])
예제 #13
0
    def generate_chunk_only(self):
        """
        Generate (fill) chunks while still in semi-epoch 0 (the "1stEpoch" mode).

        Chunks whose IShuffled file already exists are skipped without filling. The loop
        runs while ``self._semi_epoch == 0``. NOTE(review): presumably the
        ``go_to_next_chunk*`` methods advance it; confirm.
        """
        def ishuffled_address_for(folder):
            # The name depends on the current semi-epoch and chunk, so rebuild it after each move.
            return folder + su.address_generator(self._setting,
                                                 'IShuffledName',
                                                 semi_epoch=self._semi_epoch,
                                                 chunk=self._chunk)

        while self._semi_epoch == 0:
            folder = reading_utils.get_ishuffled_folder_write_ishuffled_setting(
                self._setting,
                self._train_mode,
                self._stage,
                self._number_of_images_per_chunk,
                self._samples_per_image,
                self._im_info_list_full,
                full_image=self._full_image,
                chunk_length_force_to_multiple_of=self._chunk_length_force_to_multiple_of)
            address = ishuffled_address_for(folder)

            while os.path.isfile(address) and not self._semi_epoch:
                logging.debug(
                    self._class_mode +
                    ': stage={}, SemiEpoch={}, Chunk={} is already generated, going to next chunk'
                    .format(self._stage, self._semi_epoch, self._chunk))
                self.go_to_next_chunk_without_going_to_fill()
                address = ishuffled_address_for(folder)

            if self._semi_epoch:
                continue
            self.fill()
            self.go_to_next_chunk()

        logging.debug(
            self._class_mode +
            ': exiting . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .'
        )
예제 #14
0
def landmarks_from_dvf_old(setting, IN_test_list):
    """
    Collect landmark data and the RegNet DVF at the fixed landmarks, then pickle them.

    The result is a list indexed directly by case number ``IN``. Each entry holds the
    fixed/moving/affine landmarks, the affine DVF, the non-rigid ground-truth
    displacement, the fixed landmark indices and the 'dvf_s0' values sampled at those
    indices.

    :param setting: experiment settings dict.
    :param IN_test_list: iterable of case numbers to process.
    :raises ValueError: if the output pickle already exists (it is never overwritten).
    """
    import time

    # %%------------------------------------------- Setting of generating synthetic DVFs------------------------------------------
    saved_file = su.address_generator(setting, 'landmarks_file')
    if os.path.isfile(saved_file):
        raise ValueError(
            'cannot overwrite, please change the name of the pickle file: ' +
            saved_file)

    # %%------------------------------------------------------  running   ---------------------------------------------------------
    in_list = list(IN_test_list)
    # The list is indexed by IN. Keep the historical 22 slots, but grow the list when a
    # larger case number is requested instead of failing with IndexError.
    number_of_slots = max([22] + [IN + 1 for IN in in_list])
    landmarks = [{} for _ in range(number_of_slots)]
    for IN in in_list:
        time_before = time.time()
        image_pair = real_pair.Images(IN, setting=setting)
        image_pair.prepare_for_landmarks(padding=False)
        landmark = landmarks[IN]
        landmark['FixedLandmarksWorld'] = image_pair._fixed_landmarks_world.copy()
        landmark['MovingLandmarksWorld'] = image_pair._moving_landmarks_world.copy()
        landmark['FixedAfterAffineLandmarksWorld'] = image_pair._fixed_after_affine_landmarks_world.copy()
        landmark['DVFAffine'] = image_pair._dvf_affine.copy()
        landmark['DVF_nonrigidGroundTruth'] = \
            image_pair._moving_landmarks_world - image_pair._fixed_after_affine_landmarks_world
        landmark['FixedLandmarksIndex'] = image_pair._fixed_landmarks_index
        dvf_s0_sitk = sitk.ReadImage(
            su.address_generator(setting, 'dvf_s0', cn=IN))
        dvf_s0 = sitk.GetArrayFromImage(dvf_s0_sitk)
        fixed_index = image_pair._fixed_landmarks_index
        # Landmark indices are stored as (x, y, z); the numpy array is indexed (z, y, x).
        landmark['DVFRegNet'] = np.stack([
            dvf_s0[fixed_index[i, 2], fixed_index[i, 1], fixed_index[i, 0]]
            for i in range(len(fixed_index))
        ])
        # The message previously ended with "is done in " and no duration.
        print('CPU: IN = {} is done in {:.2f}s'.format(IN, time.time() - time_before))
    with open(saved_file, 'wb') as f:
        pickle.dump(landmarks, f)
예제 #15
0
파일: transform.py 프로젝트: zlinzju/RegNet
def affine(setting, pair_info, parameter_name, stage=1, overwrite=False):
    """
    Run an elastix affine registration of the moving image onto the fixed image.

    The run is skipped (returning 0) when the 'MovedImBaseReg' image already exists and
    ``overwrite`` is False. When ``setting['Reg_BaseReg_Mask']`` is not None, that mask
    is passed for both images.

    :param setting: experiment settings dict.
    :param pair_info: [fixed_info, moving_info]; ``pair_info[0]`` must contain
                      'data', 'cn' and 'type_im'.
    :param parameter_name: elastix parameter file name inside the 'ParameterFolder'.
    :param stage: stage written into both image-info dicts.
    :param overwrite: re-run even if the moved image exists.
    :return: 0 when skipped, otherwise None.
    """
    fixed_info = copy.deepcopy(pair_info[0])
    moving_info = copy.deepcopy(pair_info[1])
    fixed_info['stage'] = stage
    moving_info['stage'] = stage

    moved_address = su.address_generator(setting, 'MovedImBaseReg', pair_info=pair_info, **fixed_info)
    case_text = pair_info[0]['data'] + ' CN = {} TypeIm = {}'.format(
        pair_info[0]['cn'], pair_info[0]['type_im'])
    if os.path.isfile(moved_address):
        if not overwrite:
            logging.debug('Affine registration skipping... Data=' + case_text)
            return 0
        logging.debug('Affine registration overwriting... Data=' + case_text)
    else:
        logging.debug('Affine registration starting... Data=' + case_text)

    mask_name = setting['Reg_BaseReg_Mask']
    if mask_name is None:
        fixed_mask = moving_mask = None
    else:
        fixed_mask = su.address_generator(setting, mask_name, **fixed_info)
        moving_mask = su.address_generator(setting, mask_name, **moving_info)

    parameter_file = su.address_generator(
        setting, 'ParameterFolder', pair_info=pair_info, **fixed_info) + parameter_name
    output_directory = su.address_generator(setting, 'AffineFolder', pair_info=pair_info, **fixed_info)
    fixed_image = su.address_generator(setting, 'Im', **fixed_info)
    moving_image = su.address_generator(setting, 'Im', **moving_info)
    elxpy.elastix(parameter_file=parameter_file,
                  output_directory=output_directory,
                  elastix_address='elastix',
                  fixed_image=fixed_image,
                  moving_image=moving_image,
                  fixed_mask=fixed_mask,
                  moving_mask=moving_mask,
                  initial_transform=None,
                  threads=setting['Reg_NumberOfThreads'])
예제 #16
0
def initialize(current_experiment, stage_list, folder_script='functions'):
    """
    Parse the command line, build the settings and back up the running code.

    A timestamped folder ``<result_step_folder>CodeBackup/backup-YYYYMMDD_HHMMSS/`` is
    used for the log file ('log.txt'), a copy of this script and a copy of
    ``folder_script``.

    :param current_experiment: experiment name passed to ``su.initialize_setting``.
    :param stage_list: stages used to locate the 'result_step_folder'.
    :param folder_script: sibling folder of this script to copy into the backup.
    :return: (setting, backup_folder)
    """
    parser = argparse.ArgumentParser(description='read where_to_run')
    parser.add_argument('--where_to_run', '-w',
                        help='This is an optional argument, '
                             'you choose between "Auto" or "Cluster". The default value is "Auto"')
    where_to_run = parser.parse_args().where_to_run

    setting = su.initialize_setting(current_experiment=current_experiment, where_to_run=where_to_run)
    backup_number = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    result_folder = su.address_generator(setting, 'result_step_folder', stage_list=stage_list)
    backup_folder = result_folder + 'CodeBackup/' + 'backup-' + backup_number + '/'

    gut.logger.set_log_file(backup_folder + 'log.txt', short_mode=True)
    script = Path(__file__)
    backup_path = Path(backup_folder)
    shutil.copy(script, backup_path / script.name)
    shutil.copytree(script.parent / folder_script, backup_path / folder_script)
    return setting, backup_folder
예제 #17
0
파일: real_pair.py 프로젝트: zlinzju/RegNet
    def load_pair(self):
        """
        Read the fixed image, the base-registered moving image and optionally their masks.

        If the fixed image info contains 'dsmooth', the pair is synthetic. Then the
        fixed image is the 'DeformedIm' and the "moved" image is the undeformed 'Im'.
        Otherwise the fixed 'Im' is paired with the 'MovedImBaseReg' of the moving image
        for ``self._setting['BaseReg']``. Masks are read only when
        ``self._mask_to_zero`` is not None.

        :return: (fixed_im_sitk, moved_im_affine_sitk, fixed_mask_sitk, moved_mask_affine_sitk);
                 the two masks are None when no mask is requested.
        """
        base_reg = self._setting['BaseReg']
        mask_name = self._mask_to_zero
        im_info_fixed = copy.deepcopy(self._pair_info[0])
        im_info_fixed['stage'] = self._stage

        if 'dsmooth' in im_info_fixed:
            # Synthetic pair: both images come from the same original case.
            im_info_su = {key: im_info_fixed[key] for key in ('data', 'deform_exp', 'type_im', 'cn', 'dsmooth')}
            im_info_su['stage'] = self._stage
            im_info_su['padto'] = im_info_fixed['padto']
            fixed_im_address = su.address_generator(self._setting, 'DeformedIm',
                                                    deformed_im_ext=im_info_fixed['deformed_im_ext'], **im_info_su)
            moved_im_affine_address = su.address_generator(self._setting, 'Im', **im_info_su)
            if mask_name is not None:
                fixed_mask_address = su.address_generator(self._setting, 'Deformed' + mask_name, **im_info_su)
                moved_mask_affine_address = su.address_generator(self._setting, mask_name, **im_info_su)
        else:
            im_info_moving = copy.deepcopy(self._pair_info[1])
            im_info_moving['stage'] = self._stage
            fixed_im_address = su.address_generator(self._setting, 'Im', **im_info_fixed)
            moved_im_affine_address = su.address_generator(self._setting, 'MovedImBaseReg', pair_info=self._pair_info,
                                                           base_reg=base_reg, **im_info_moving)
            if mask_name is not None:
                fixed_mask_address = su.address_generator(self._setting, mask_name, **im_info_fixed)
                moved_mask_affine_address = su.address_generator(self._setting, 'Moved' + mask_name + 'BaseReg',
                                                                 pair_info=self._pair_info, base_reg=base_reg,
                                                                 **im_info_moving)

        fixed_im_sitk = sitk.ReadImage(fixed_im_address)
        moved_im_affine_sitk = sitk.ReadImage(moved_im_affine_address)
        logging.info('FixedIm:' + fixed_im_address)
        logging.info('MovedImBaseReg:' + moved_im_affine_address)

        fixed_mask_sitk = moved_mask_affine_sitk = None
        if mask_name is not None:
            fixed_mask_sitk = sitk.ReadImage(fixed_mask_address)
            moved_mask_affine_sitk = sitk.ReadImage(moved_mask_affine_address)
            logging.info('FixedMask:' + fixed_mask_address)
            logging.info('MovedMaskBaseReg:' + moved_mask_affine_address)

        return fixed_im_sitk, moved_im_affine_sitk, fixed_mask_sitk, moved_mask_affine_sitk
예제 #18
0
def check_all_images_exist(setting, im_info, stage, mask_to_zero=None):
    """
    Report whether every image belonging to a synthetic pair is present on disk.

    Addresses for 'Im', 'DeformedIm' and 'DeformedDVF' -- plus ``mask_to_zero`` and
    'Deformed' + ``mask_to_zero`` when a mask name is given -- are produced by
    ``su.address_generator`` and checked with ``os.path.isfile``.

    :param setting: global setting dict, forwarded to su.address_generator.
    :param im_info: dict that must provide the keys
                    'data', 'deform_exp', 'type_im',
                    'cn':        case number (image number); note that it starts from 1, not 0,
                    'dsmooth':   used to generate another deformed version of the moving image,
                                 which is then used to make synthetic DVFs
                                 (more information in [sokooti2017nonrigid]),
                    'padto', 'deformed_im_ext'.
    :param stage: stage value forwarded to su.address_generator.
    :param mask_to_zero: optional mask image name; when given, the mask and its deformed
                         counterpart must exist as well.
    :return: True if all files exist, False otherwise (checking stops at the first missing file).
    """
    address_kwargs = dict(
        data=im_info['data'],
        deform_exp=im_info['deform_exp'],
        type_im=im_info['type_im'],
        cn=im_info['cn'],
        dsmooth=im_info['dsmooth'],
        stage=stage,
        padto=im_info['padto'],
        deformed_im_ext=im_info['deformed_im_ext'],
    )

    required_names = ['Im', 'DeformedIm', 'DeformedDVF']
    if mask_to_zero is not None:
        required_names += [mask_to_zero, 'Deformed' + mask_to_zero]

    # all() short-circuits on the first missing file, exactly like an explicit break.
    return all(
        os.path.isfile(su.address_generator(setting, name, **address_kwargs))
        for name in required_names
    )
예제 #19
0
def add_sponge_model(setting,
                     im_info,
                     stage,
                     deformed_im_previous_sitk=None,
                     dvf=None,
                     deformed_torso_sitk=None,
                     spacing=None,
                     gonna_generate_next_im=False):
    """
    Modulate a deformed image by the Jacobian of its deformation ("sponge" model).

    The output intensities are ``previous / clip(jac, 0.7, 1.3) * s``, where ``s`` is drawn
    from U(0.9, 1.1) by a numpy RandomState seeded via ag_utils.seed_number_by_im_info.
    When setting['UseTorsoMask'] is set, voxels outside the torso mask (mask value 0)
    keep the intensities of the previous image.

    The Jacobian is read from disk if its file exists. Otherwise it is computed with
    ip.calculate_jac. When gonna_generate_next_im is False, the unclipped Jacobian is
    written to disk as float32.

    :param setting: global setting dict (reads 'UseTorsoMask'; also forwarded to su.address_generator).
    :param im_info: dict providing 'data', 'deform_exp', 'type_im', 'cn', 'dsmooth' and 'padto'.
    :param stage: stage value forwarded to su.address_generator and to the seed helper.
    :param deformed_im_previous_sitk: SimpleITK image to modulate. If None, the 'DeformedIm'
                                      with deformed_im_ext='Clean' is read from disk.
    :param dvf: DVF array, used only when the Jacobian must be computed. If None, the
                'DeformedDVF' is read from disk.
    :param deformed_torso_sitk: torso mask image, used only with UseTorsoMask. If None,
                                it is read from disk.
    :param spacing: spacing passed to ip.calculate_jac. Defaults to the reversed SimpleITK
                    spacing of deformed_im_previous_sitk.
    :param gonna_generate_next_im: if True, use the 'NextJac'/'NextTorso' addresses and do
                                   not write the Jacobian to disk.
    :return: SimpleITK image built by ip.array_to_sitk with deformed_im_previous_sitk as reference.
    """
    im_info_su = {
        'data': im_info['data'],
        'deform_exp': im_info['deform_exp'],
        'type_im': im_info['type_im'],
        'cn': im_info['cn'],
        'dsmooth': im_info['dsmooth'],
        'stage': stage,
        'padto': im_info['padto']
    }
    seed_number = ag_utils.seed_number_by_im_info(
        im_info,
        'add_sponge_model',
        stage=stage,
        gonna_generate_next_im=gonna_generate_next_im)
    random_state = np.random.RandomState(seed_number)

    if gonna_generate_next_im:
        jac_address = su.address_generator(setting, 'NextJac', **im_info_su)
        torso_address = su.address_generator(setting, 'NextTorso',
                                             **im_info_su)
    else:
        jac_address = su.address_generator(setting, 'Jac', **im_info_su)
        torso_address = su.address_generator(setting, 'DeformedTorso',
                                             **im_info_su)

    if deformed_im_previous_sitk is None:
        deformed_im_previous_sitk = sitk.ReadImage(
            su.address_generator(setting,
                                 'DeformedIm',
                                 deformed_im_ext='Clean',
                                 **im_info_su))
    # Converted once: this array is never modified in place, so it serves both as the
    # image to modulate and as the source of the untouched outside-torso voxels.
    deformed_im_previous = sitk.GetArrayFromImage(deformed_im_previous_sitk)

    if not os.path.isfile(jac_address):
        if dvf is None:
            dvf = sitk.GetArrayFromImage(
                sitk.ReadImage(
                    su.address_generator(setting, 'DeformedDVF',
                                         **im_info_su)))
        if spacing is None:
            # SimpleITK spacing is (x, y, z); reverse it to match the numpy array axis order.
            spacing = deformed_im_previous_sitk.GetSpacing()[::-1]
        jac = ip.calculate_jac(dvf, spacing)
        if not gonna_generate_next_im:
            # The Jacobian is cached before clipping.
            sitk.WriteImage(sitk.GetImageFromArray(jac.astype(np.float32)),
                            jac_address)
    else:
        jac = sitk.GetArrayFromImage(sitk.ReadImage(jac_address))

    random_scale = random_state.uniform(0.9, 1.1)
    # Limit the intensity change to a factor within [1/1.3, 1/0.7].
    jac[jac < 0.7] = 0.7
    jac[jac > 1.3] = 1.3
    deformed_im_sponge = deformed_im_previous / jac * random_scale
    if setting['UseTorsoMask']:
        # no scaling outside of Torso region.
        if deformed_torso_sitk is None:
            deformed_torso_sitk = sitk.ReadImage(torso_address)
        outside_torso = sitk.GetArrayFromImage(deformed_torso_sitk) == 0
        deformed_im_sponge[outside_torso] = deformed_im_previous[outside_torso]
    deformed_im_sponge_sitk = ip.array_to_sitk(
        deformed_im_sponge, im_ref=deformed_im_previous_sitk)

    return deformed_im_sponge_sitk
예제 #20
0
def table_box_plot(setting,
                   landmarks,
                   exp_list,
                   fig_measure_list=None,
                   plot_per_pair=False,
                   fig_ext='.png',
                   plot_folder=None,
                   paper_table=None,
                   naming_strategy=None,
                   jacobian=False,
                   label_times2=None,
                   label_times1=None,
                   step=0,
                   xlx_name=None):
    """
    Merge the landmarks from different experiments and write a summary spreadsheet.

    For every experiment in ``landmarks``:
    - one row is written per landmark pair, containing the measures from calculate_measure;
    - one row is written for all pairs of that experiment merged together.
    A final 'Total' row merges the times2 labels/logits and ground-truth magnitudes of all
    experiments.

    Rows of different experiments are interleaved: pair ``p`` of experiment ``e`` goes to
    row ``e + p * (num_exp + 1) + 1``. The merged row of ``e`` goes to
    ``e + num_pair * (num_exp + 1) + 2``.

    Side effect: every measure written for a pair is also stored in that pair's
    'landmark_info' dict.

    :param setting: global setting dict ('NumberOfLabels' sets the width of the logits arrays).
    :param landmarks: dict mapping experiment name -> list of landmark pairs. Each pair is a
                      dict with 'pair_info' and 'landmark_info'.
    :param exp_list: list of experiment dicts. Only exp_list[0]['BaseReg'] is used: it is
                     appended to the plot key to select the pair_info that locates the result folder.
    :param fig_measure_list: unused.
    :param plot_per_pair: if True, print_latex is called with every pair's measures.
    :param fig_ext: unused.
    :param plot_folder: one of the experiments, used to choose where the results are saved.
        The plot folder should also consider the stages, for example my_experiment_S4_S2_S1.
        If None, the last experiment in ``landmarks`` is chosen.
    :param paper_table: unused ('SPREAD', 'DIR-Lab').
    :param naming_strategy: None, 'Clean' or 'Fancy': how the experiment name is written in
                            the merged rows.
    :param jacobian: unused.
    :param label_times2: forwarded to calculate_measure.
    :param label_times1: forwarded to calculate_measure.
    :param step: step value used to locate the result folder.
    :param xlx_name: spreadsheet file name without extension (default 'results').
    :return: None. The workbook is written to result_folder + xlx_name + '.xlsx'.
    """
    landmarks_merged = dict()
    if plot_folder is None:
        plot_key = list(landmarks)[-1]
    else:
        plot_key = plot_folder
    if xlx_name is None:
        xlx_name = 'results'
    stage_list = [4, 2, 1]

    result_folder = su.address_generator(
        setting,
        'result_detail_folder',
        current_experiment=plot_key,
        stage_list=stage_list,
        step=step,
        pair_info=landmarks[plot_key + '_' +
                            exp_list[0]['BaseReg']][0]['pair_info'])
    if not os.path.isdir(result_folder):
        os.makedirs(result_folder)

    # xlsxwriter always creates a new file, so an existing workbook at this address
    # is replaced when the workbook is closed.
    xlsx_address = result_folder + xlx_name + '.xlsx'
    workbook = xlsxwriter.Workbook(xlsx_address)
    try:
        worksheet = workbook.add_worksheet()
        line = 0
        header = {
            'exp': 0,
            'P0A0': 1,
            'P0A1': 2,
            'P0A2': 3,
            'P1A0': 4,
            'P1A1': 5,
            'P1A2': 6,
            'P2A0': 7,
            'P2A1': 8,
            'P2A2': 9,
            'A0': 10,
            'A1': 11,
            'A2': 12,
            'F0': 13,
            'F1': 14,
            'F2': 15,
            'Acc': 16
        }

        for key, column in header.items():
            worksheet.write(line, column, key)
        num_labels = setting['NumberOfLabels']
        num_exp = len(landmarks)
        for exp_i, exp in enumerate(landmarks):
            merged = {
                'DVF_error_times2_label': np.empty([0]),
                'DVF_error_times2_logits': np.empty([0, num_labels]),
                'DVF_error_times1_label': np.empty([0]),
                'DVF_error_times1_logits': np.empty([0, num_labels]),
                'DVF_error_times0_label': np.empty([0]),
                'DVF_error_times0_logits': np.empty([0, num_labels]),
                'DVF_nonrigidGroundTruth_magnitude': np.empty([0]),
                'CleanName': su.clean_exp_name(exp),
                'FancyName': su.fancy_exp_name(exp),
            }
            landmarks_merged[exp] = merged

            num_pair = len(landmarks[exp])
            for pair_i, landmark_pair in enumerate(landmarks[exp]):
                pair_info = landmark_pair['pair_info']
                pair_landmark_info = landmark_pair['landmark_info']
                pair_info_text = merged['CleanName'] + '_Fixed_' + pair_info[0]['data'] + \
                    '_CN{}_TypeIm{},'.format(pair_info[0]['cn'], pair_info[0]['type_im']) + '_Moving_' + \
                    pair_info[1]['data'] + '_CN{}_TypeIm{}'.format(pair_info[1]['cn'], pair_info[1]['type_im'])
                for i in range(3):
                    logits_key = 'DVF_error_times{}_logits'.format(i)
                    label_key = 'DVF_error_times{}_label'.format(i)
                    # Only the logits key is tested; the matching label array is expected alongside it.
                    if logits_key in pair_landmark_info:
                        merged[logits_key] = np.vstack(
                            (merged[logits_key], pair_landmark_info[logits_key]))
                        merged[label_key] = np.append(
                            merged[label_key], pair_landmark_info[label_key])
                merged['DVF_nonrigidGroundTruth_magnitude'] = np.append(
                    merged['DVF_nonrigidGroundTruth_magnitude'],
                    pair_landmark_info['DVF_nonrigidGroundTruth_magnitude'])

                measure = calculate_measure(pair_landmark_info,
                                            label_times2=label_times2,
                                            label_times1=label_times1)
                measure['exp'] = pair_info_text
                if plot_per_pair:
                    print_latex(measure)
                line = exp_i + pair_i * (num_exp + 1) + 1
                for key, column in header.items():
                    if key in measure:
                        worksheet.write(line, column, measure[key])
                        pair_landmark_info[key] = measure[key]

            measure_merged = calculate_measure(merged,
                                               label_times2=label_times2,
                                               label_times1=label_times1)
            if naming_strategy == 'Clean':
                measure_merged['exp'] = su.clean_exp_name(exp)
            elif naming_strategy == 'Fancy':
                measure_merged['exp'] = su.fancy_exp_name(exp)
            else:
                measure_merged['exp'] = exp
            line = exp_i + num_pair * (num_exp + 1) + 2
            for key, column in header.items():
                # Test the merged measure itself. The per-pair ``measure`` only reflects the
                # last pair, and is unbound or stale for an experiment without pairs.
                if key in measure_merged:
                    worksheet.write(line, column, measure_merged[key])

        full_merge = {
            'DVF_error_times2_label': np.empty([0]),
            'DVF_nonrigidGroundTruth_magnitude': np.empty([0]),
            'DVF_error_times2_logits': np.empty([0, num_labels]),
        }
        for merged in landmarks_merged.values():
            full_merge['DVF_error_times2_logits'] = np.vstack(
                (full_merge['DVF_error_times2_logits'],
                 merged['DVF_error_times2_logits']))
            full_merge['DVF_error_times2_label'] = np.append(
                full_merge['DVF_error_times2_label'],
                merged['DVF_error_times2_label'])
            full_merge['DVF_nonrigidGroundTruth_magnitude'] = np.append(
                full_merge['DVF_nonrigidGroundTruth_magnitude'],
                merged['DVF_nonrigidGroundTruth_magnitude'])
        measure_full_merged = calculate_measure(full_merge,
                                                label_times2=label_times2,
                                                label_times1=label_times1)
        measure_full_merged['exp'] = 'Total'
        # The 'Total' row goes directly below the last written merged row.
        line = line + 1
        for key, column in header.items():
            if key in measure_full_merged:
                worksheet.write(line, column, measure_full_merged[key])
    finally:
        # Always release the workbook, even if a measure computation fails.
        workbook.close()
예제 #21
0
def calculate_write_landmark(setting,
                             pair_info,
                             overwrite_landmarks=False,
                             overwrite_landmarks_hard=False,
                             base_reg=None):
    """
    Compute the landmark errors of one image pair and store them in the landmark file.

    The landmark file (su.address_generator(setting, 'landmarks_file', ...)) holds a
    dill-serialized list with one dict per pair. A pair is matched on the 'data', 'cn'
    and 'type_im' of its fixed and moving images.

    The pair is (re)computed when:
    - it is not in the list yet, or
    - overwrite_landmarks is True, or
    - overwrite_landmarks_hard is True.

    Computing means running landmark_info and then reg.multi_stage_error, and adding deep
    copies of setting['network_dict'] and setting['network_lstm_dict'] to the result. The
    result replaces the existing entry of the pair, or is appended when there is none. The
    whole list is then written back. According to the original notes, landmark_info
    provides keys such as 'FixedLandmarksWorld', 'MovingLandmarksWorld',
    'FixedAfterAffineLandmarksWorld', 'DVFAffine', 'DVF_nonrigidGroundTruth' and
    'FixedLandmarksIndex' (not verifiable from here).

    :param setting: global setting dict (reads 'ImagePyramidSchedule', 'lstm_exp',
                    'network_dict' and 'network_lstm_dict').
    :param pair_info: [fixed_info, moving_info] dicts, each with at least 'data', 'cn' and 'type_im'.
    :param overwrite_landmarks: recompute the pair even if it is already stored.
    :param overwrite_landmarks_hard: also forces recomputation of an existing pair.
    :param base_reg: forwarded to landmark_info.
    :return: the full list of landmark dicts (as stored in the landmark file).
    """
    time_before = time.time()
    stage_list = setting['ImagePyramidSchedule']
    landmark_address = su.address_generator(
        setting,
        'landmarks_file',
        stage_list=stage_list,
        current_experiment=setting['lstm_exp'],
        step=setting['network_lstm_dict']['GlobalStepLoad'])
    if os.path.isfile(landmark_address):
        # NOTE: dill, like pickle, can execute arbitrary code while loading;
        # only load landmark files produced by this pipeline.
        with open(landmark_address, 'rb') as f:
            landmark = dill.load(f)
    else:
        landmark = []
        result_landmarks_folder = su.address_generator(
            setting,
            'result_landmarks_folder',
            stage_list=stage_list,
            current_experiment=setting['lstm_exp'],
            step=setting['network_lstm_dict']['GlobalStepLoad'])
        if not os.path.isdir(result_landmarks_folder):
            os.makedirs(result_landmarks_folder)

    pair_info_text = 'stage_list={}'.format(stage_list) + ' Fixed: ' + pair_info[0]['data'] + \
                     '_CN{}_TypeIm{},'.format(pair_info[0]['cn'], pair_info[0]['type_im']) + '  Moving:' + \
                     pair_info[1]['data'] + '_CN{}_TypeIm{}'.format(pair_info[1]['cn'], pair_info[1]['type_im'])

    # Index of the first stored entry that describes this pair, or None if there is none.
    ind_find = next(
        (i for i, landmark_i in enumerate(landmark)
         if compare_pair_info_dict(pair_info,
                                   landmark_i['pair_info'],
                                   compare_keys=['data', 'cn', 'type_im'])),
        None)

    if ind_find is None:
        calculate_multi_stage_error = True
    elif overwrite_landmarks:
        calculate_multi_stage_error = True
        logging.info('overwriting ' + pair_info_text)
    else:
        calculate_multi_stage_error = False
        logging.info('Skipping ' + pair_info_text)

    if calculate_multi_stage_error or overwrite_landmarks_hard:
        landmark_dict = landmark_info(setting, pair_info, base_reg=base_reg)
        landmark_dict = reg.multi_stage_error(setting,
                                              landmark_dict,
                                              pair_info=pair_info)
        landmark_dict['network_dict'] = copy.deepcopy(setting['network_dict'])
        landmark_dict['network_lstm_dict'] = copy.deepcopy(
            setting['network_lstm_dict'])
        if ind_find is None:
            landmark.append(landmark_dict)
        else:
            # Replace the stale entry rather than appending a duplicate. Lookups take the
            # first match, so an appended copy would never be found, and the pair would be
            # counted twice by anything iterating the list.
            landmark[ind_find] = landmark_dict

        with open(landmark_address, 'wb') as f:
            dill.dump(landmark, f)

    time_after = time.time()
    logging.debug('Landmark ' + pair_info_text +
                  ' is done in {:.2f}s '.format(time_after - time_before))

    return landmark
예제 #22
0
def multi_stage(setting, pair_info, overwrite=False):
    """
    :param setting:
    :param pair_info: information of the pair to be registered.
    :param overwrite:
    :return: The output moved images and dvf will be written to the disk.
             1: registration is performed correctly
             2: skip overwriting
             3: the dvf is available from the previous experiment [4, 2, 1]. Then just upsample it.
    """
    stage_list = setting['ImagePyramidSchedule']
    if setting['read_pair_mode'] == 'synthetic':
        deformed_im_ext = pair_info[0].get('deformed_im_ext', None)
        im_info_su = {
            'data': pair_info[0]['data'],
            'deform_exp': pair_info[0]['deform_exp'],
            'type_im': pair_info[0]['type_im'],
            'cn': pair_info[0]['cn'],
            'dsmooth': pair_info[0]['dsmooth'],
            'padto': pair_info[0]['padto'],
            'deformed_im_ext': deformed_im_ext
        }
        moved_im_s0_address = su.address_generator(setting,
                                                   'MovedIm_AG',
                                                   stage=1,
                                                   **im_info_su)
        moved_torso_s1_address = su.address_generator(setting,
                                                      'MovedTorso_AG',
                                                      stage=1,
                                                      **im_info_su)
        moved_lung_s1_address = su.address_generator(setting,
                                                     'MovedLung_AG',
                                                     stage=1,
                                                     **im_info_su)
    else:
        moved_im_s0_address = su.address_generator(setting,
                                                   'MovedIm',
                                                   pair_info=pair_info,
                                                   stage=0,
                                                   stage_list=stage_list)
        moved_torso_s1_address = None
        moved_lung_s1_address = None

    if setting['read_pair_mode'] == 'synthetic':
        if os.path.isfile(moved_im_s0_address) and os.path.isfile(
                moved_torso_s1_address):
            if not overwrite:
                logging.debug('overwrite=False, file ' + moved_im_s0_address +
                              ' already exists, skipping .....')
                return 2
            else:
                logging.debug('overwrite=True, file ' + moved_im_s0_address +
                              ' already exists, but overwriting .....')
    else:
        if os.path.isfile(moved_im_s0_address):
            if not overwrite:
                logging.debug('overwrite=False, file ' + moved_im_s0_address +
                              ' already exists, skipping .....')
                return 2
            else:
                logging.debug('overwrite=True, file ' + moved_im_s0_address +
                              ' already exists, but overwriting .....')

    pair_stage1 = real_pair.Images(
        setting, pair_info, stage=1, padto=setting['PadTo']
        ['stage1'])  # just read the original images without any padding
    pyr = dict()  # pyr: a dictionary of pyramid images
    pyr['fixed_im_s1_sitk'] = pair_stage1.get_fixed_im_sitk()
    pyr['moving_im_s1_sitk'] = pair_stage1.get_moved_im_affine_sitk()
    pyr['fixed_im_s1'] = pair_stage1.get_fixed_im()
    pyr['moving_im_s1'] = pair_stage1.get_moved_im_affine()
    if setting['UseMask']:
        pyr['fixed_mask_s1_sitk'] = pair_stage1.get_fixed_mask_sitk()
        pyr['moving_mask_s1_sitk'] = pair_stage1.get_moved_mask_affine_sitk()
    if setting['read_pair_mode'] == 'real':
        if not (os.path.isdir(
                su.address_generator(setting,
                                     'full_reg_folder',
                                     pair_info=pair_info,
                                     stage_list=stage_list))):
            os.makedirs(
                su.address_generator(setting,
                                     'full_reg_folder',
                                     pair_info=pair_info,
                                     stage_list=stage_list))
    setting['GPUMemory'], setting['NumberOfGPU'] = tfu.client.read_gpu_memory()
    time_before_dvf = time.time()

    # check if DVF is available from the previous experiment [4, 2, 1]. Then just upsample it.
    if stage_list in [[4, 2], [4]]:
        dvf0_address = su.address_generator(setting,
                                            'dvf_s0',
                                            pair_info=pair_info,
                                            stage_list=stage_list)
        chosen_stage = None
        if stage_list == [4, 2]:
            chosen_stage = 2
        elif stage_list == [4]:
            chosen_stage = 4
        if chosen_stage is not None:
            dvf_s_up_address = su.address_generator(setting,
                                                    'dvf_s_up',
                                                    pair_info=pair_info,
                                                    stage=chosen_stage,
                                                    stage_list=[4, 2, 1])
            if os.path.isfile(dvf_s_up_address):
                logging.debug('DVF found from prev exp:' + dvf_s_up_address +
                              ', only performing upsampling')
                dvf_s_up = sitk.ReadImage(dvf_s_up_address)
                dvf0 = ip.resampler_sitk(
                    dvf_s_up,
                    scale=1 / (chosen_stage / 2),
                    im_ref_size=pyr['fixed_im_s1_sitk'].GetSize(),
                    interpolator=sitk.sitkLinear)
                sitk.WriteImage(sitk.Cast(dvf0, sitk.sitkVectorFloat32),
                                dvf0_address)
                return 3

    for i_stage, stage in enumerate(setting['ImagePyramidSchedule']):
        mask_to_zero_stage = setting['network_dict']['stage' +
                                                     str(stage)]['MaskToZero']
        if stage != 1:
            pyr['fixed_im_s' + str(stage) + '_sitk'] = ip.downsampler_gpu(
                pyr['fixed_im_s1_sitk'],
                stage,
                default_pixel_value=setting['data'][
                    pair_info[0]['data']]['DefaultPixelValue'])
            pyr['moving_im_s' + str(stage) + '_sitk'] = ip.downsampler_gpu(
                pyr['moving_im_s1_sitk'],
                stage,
                default_pixel_value=setting['data'][
                    pair_info[1]['data']]['DefaultPixelValue'])
        if setting['UseMask']:
            pyr['fixed_mask_s' + str(stage) + '_sitk'] = ip.resampler_sitk(
                pyr['fixed_mask_s1_sitk'],
                scale=stage,
                im_ref=pyr['fixed_im_s' + str(stage) + '_sitk'],
                default_pixel_value=0,
                interpolator=sitk.sitkNearestNeighbor)
            pyr['moving_mask_s' + str(stage) + '_sitk'] = ip.resampler_sitk(
                pyr['moving_mask_s1_sitk'],
                scale=stage,
                im_ref=pyr['moving_im_s' + str(stage) + '_sitk'],
                default_pixel_value=0,
                interpolator=sitk.sitkNearestNeighbor)

            if setting['WriteMasksForLSTM']:
                # only to be used in sequential training (LSTM)
                if setting['read_pair_mode'] == 'synthetic':
                    fixed_mask_stage_address = su.address_generator(
                        setting,
                        'Deformed' + mask_to_zero_stage,
                        stage=stage,
                        **im_info_su)
                    moving_mask_stage_address = su.address_generator(
                        setting, mask_to_zero_stage, stage=stage, **im_info_su)
                    fixed_im_stage_address = su.address_generator(setting,
                                                                  'DeformedIm',
                                                                  stage=stage,
                                                                  **im_info_su)
                    sitk.WriteImage(
                        sitk.Cast(
                            pyr['fixed_im_s' + str(stage) + '_sitk'],
                            setting['data'][pair_info[1]['data']]
                            ['ImageByte']), fixed_im_stage_address)
                    sitk.WriteImage(pyr['fixed_mask_s' + str(stage) + '_sitk'],
                                    fixed_mask_stage_address)
                    if im_info_su['dsmooth'] != 0 and stage == 4:
                        # not overwirte original images
                        moving_im_stage_address = su.address_generator(
                            setting, 'Im', stage=stage, **im_info_su)
                        sitk.WriteImage(
                            sitk.Cast(
                                pyr['moving_im_s' + str(stage) + '_sitk'],
                                setting['data'][pair_info[1]['data']]
                                ['ImageByte']), moving_im_stage_address)
                        sitk.WriteImage(
                            pyr['moving_mask_s' + str(stage) + '_sitk'],
                            moving_mask_stage_address)
                else:
                    fixed_im_stage_address = su.address_generator(
                        setting, 'Im', stage=stage, **pair_info[0])
                    fixed_mask_stage_address = su.address_generator(
                        setting,
                        mask_to_zero_stage,
                        stage=stage,
                        **pair_info[0])
                    if not os.path.isfile(fixed_im_stage_address):
                        sitk.WriteImage(
                            sitk.Cast(
                                pyr['fixed_im_s' + str(stage) + '_sitk'],
                                setting['data'][pair_info[1]['data']]
                                ['ImageByte']), fixed_im_stage_address)
                    if not os.path.isfile(fixed_mask_stage_address):
                        sitk.WriteImage(
                            pyr['fixed_mask_s' + str(stage) + '_sitk'],
                            fixed_mask_stage_address)
                    if i_stage == 0:
                        moved_im_affine_stage_address = su.address_generator(
                            setting,
                            'MovedImBaseReg',
                            pair_info=pair_info,
                            stage=stage,
                            **pair_info[1])
                        moved_mask_affine_stage_address = su.address_generator(
                            setting,
                            'Moved' + mask_to_zero_stage + 'BaseReg',
                            pair_info=pair_info,
                            stage=stage,
                            **pair_info[1])
                        if not os.path.isfile(moved_im_affine_stage_address):
                            sitk.WriteImage(
                                sitk.Cast(
                                    pyr['moving_im_s' + str(stage) + '_sitk'],
                                    setting['data'][pair_info[1]
                                                    ['data']]['ImageByte']),
                                moved_im_affine_stage_address)
                        if not os.path.isfile(moved_mask_affine_stage_address):
                            sitk.WriteImage(
                                pyr['moving_mask_s' + str(stage) + '_sitk'],
                                moved_mask_affine_stage_address)

        else:
            pyr['fixed_mask_s' + str(stage) + '_sitk'] = None
            pyr['moving_mask_s' + str(stage) + '_sitk'] = None
        input_regnet_moving_mask = None
        if i_stage == 0:
            input_regnet_moving = 'moving_im_s' + str(stage) + '_sitk'
            if setting['UseMask']:
                input_regnet_moving_mask = 'moving_mask_s' + str(
                    stage) + '_sitk'
        else:
            previous_pyramid = setting['ImagePyramidSchedule'][i_stage - 1]
            dvf_composed_previous_up_sitk = 'DVF_s' + str(
                previous_pyramid) + '_composed_up_sitk'
            dvf_composed_previous_sitk = 'DVF_s' + str(
                previous_pyramid) + '_composed_sitk'
            if i_stage == 1:
                pyr[dvf_composed_previous_sitk] = pyr[
                    'DVF_s' +
                    str(setting['ImagePyramidSchedule'][i_stage - 1]) +
                    '_sitk']
            elif i_stage > 1:
                pyr[dvf_composed_previous_sitk] = sitk.Add(
                    pyr['DVF_s' +
                        str(setting['ImagePyramidSchedule'][i_stage - 2]) +
                        '_composed_up_sitk'],
                    pyr['DVF_s' +
                        str(setting['ImagePyramidSchedule'][i_stage - 1]) +
                        '_sitk'])
            pyr[dvf_composed_previous_up_sitk] = ip.upsampler_gpu(
                pyr[dvf_composed_previous_sitk],
                round(previous_pyramid / stage),
                output_shape_3d=pyr['fixed_im_s' + str(stage) +
                                    '_sitk'].GetSize()[::-1],
            )
            if setting['WriteAfterEachStage'] and not setting['WriteNoDVF']:
                sitk.WriteImage(
                    sitk.Cast(pyr[dvf_composed_previous_up_sitk],
                              sitk.sitkVectorFloat32),
                    su.address_generator(setting,
                                         'dvf_s_up',
                                         pair_info=pair_info,
                                         stage=previous_pyramid,
                                         stage_list=stage_list))

            dvf_t = sitk.DisplacementFieldTransform(
                pyr[dvf_composed_previous_up_sitk])
            # after this line DVF_composed_previous_up_sitk is converted to a transform. so we need to load it again.
            pyr['moved_im_s' + str(stage) +
                '_sitk'] = ip.resampler_by_transform(
                    pyr['moving_im_s' + str(stage) + '_sitk'],
                    dvf_t,
                    default_pixel_value=setting['data'][
                        pair_info[1]['data']]['DefaultPixelValue'])
            if setting['UseMask']:
                pyr['moved_mask_s' + str(stage) +
                    '_sitk'] = ip.resampler_by_transform(
                        pyr['moving_mask_s' + str(stage) + '_sitk'],
                        dvf_t,
                        default_pixel_value=0,
                        interpolator=sitk.sitkNearestNeighbor)

            pyr[dvf_composed_previous_up_sitk] = dvf_t.GetDisplacementField()
            if setting['WriteAfterEachStage']:
                if setting['read_pair_mode'] == 'synthetic':
                    moved_im_s_address = su.address_generator(setting,
                                                              'MovedIm_AG',
                                                              stage=stage,
                                                              **im_info_su)
                    moved_mask_s_address = su.address_generator(
                        setting,
                        'Moved' + mask_to_zero_stage + '_AG',
                        stage=stage,
                        **im_info_su)
                else:
                    moved_im_s_address = su.address_generator(
                        setting,
                        'MovedIm',
                        pair_info=pair_info,
                        stage=stage,
                        stage_list=stage_list)
                    moved_mask_s_address = su.address_generator(
                        setting,
                        'Moved' + mask_to_zero_stage,
                        pair_info=pair_info,
                        stage=stage,
                        stage_list=stage_list)

                sitk.WriteImage(
                    sitk.Cast(
                        pyr['moved_im_s' + str(stage) + '_sitk'],
                        setting['data'][pair_info[1]['data']]['ImageByte']),
                    moved_im_s_address)

                if setting['WriteMasksForLSTM']:
                    sitk.WriteImage(pyr['moved_mask_s' + str(stage) + '_sitk'],
                                    moved_mask_s_address)

            input_regnet_moving = 'moved_im_s' + str(stage) + '_sitk'
            if setting['UseMask']:
                input_regnet_moving_mask = 'moved_mask_s' + str(
                    stage) + '_sitk'

        pyr['DVF_s' + str(stage)] = np.zeros(
            np.r_[pyr['fixed_im_s' + str(stage) + '_sitk'].GetSize()[::-1], 3],
            dtype=np.float64)
        if setting['network_dict']['stage'+str(stage)]['R'] == 'Auto' and \
                setting['network_dict']['stage'+str(stage)]['Ry'] == 'Auto':
            current_network_name = setting['network_dict'][
                'stage' + str(stage)]['NetworkDesign']
            r_out_erode_default = setting['network_dict'][
                'stage' + str(stage)]['Ry_erode']
            r_in, r_out, r_out_erode = network.utils.find_optimal_radius(
                pyr['fixed_im_s' + str(stage) + '_sitk'],
                current_network_name,
                r_out_erode_default,
                gpu_memory=setting['GPUMemory'],
                number_of_gpu=setting['NumberOfGPU'])

        else:
            r_in = setting['network_dict']['stage' + str(stage)][
                'R']  # Radius of normal resolution patch size. Total size is (2*R +1)
            r_out = setting['network_dict']['stage' + str(stage)][
                'Ry']  # Radius of output. Total size is (2*Ry +1)
            r_out_erode = setting['network_dict']['stage' + str(stage)][
                'Ry_erode']  # at the test time, sometimes there are some problems at the border

        logging.debug(
            'stage' + str(stage) + ' ,' + pair_info[0]['data'] +
            ', CN{}, ImType{}, Size={}'.format(
                pair_info[0]['cn'], pair_info[0]['type_im'], pyr[
                    'fixed_im_s' + str(stage) + '_sitk'].GetSize()[::-1]) +
            ', ' +
            setting['network_dict']['stage' + str(stage)]['NetworkDesign'] +
            ': r_in:{}, r_out:{}, r_out_erode:{}'.format(
                r_in, r_out, r_out_erode))
        pair_pyramid = real_pair.Images(
            setting,
            pair_info,
            stage=stage,
            fixed_im_sitk=pyr['fixed_im_s' + str(stage) + '_sitk'],
            moved_im_affine_sitk=pyr[input_regnet_moving],
            fixed_mask_sitk=pyr['fixed_mask_s' + str(stage) + '_sitk'],
            moved_mask_affine_sitk=pyr[input_regnet_moving_mask],
            padto=setting['PadTo']['stage' + str(stage)],
            r_in=r_in,
            r_out=r_out,
            r_out_erode=r_out_erode)

        # building and loading network
        tf.reset_default_graph()
        images_tf = tf.placeholder(
            tf.float32,
            shape=[None, 2 * r_in + 1, 2 * r_in + 1, 2 * r_in + 1, 2],
            name="Images")
        bn_training = tf.placeholder(tf.bool, name='bn_training')
        dvf_tf = getattr(
            getattr(
                network, setting['network_dict']['stage' +
                                                 str(stage)]['NetworkDesign']),
            'network')(images_tf, bn_training)
        logging.debug(' Total number of variables %s' % (np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ])))
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        sess = tf.Session()
        saver.restore(
            sess,
            su.address_generator(
                setting,
                'saved_model_with_step',
                current_experiment=setting['network_dict'][
                    'stage' + str(stage)]['NetworkLoad'],
                step=setting['network_dict']['stage' +
                                             str(stage)]['GlobalStepLoad']))
        while not pair_pyramid.get_sweep_completed():
            # The pyr[DVF_S] is the numpy DVF which will be filled. the dvf_np is an output
            # patch from the network. We control the spatial location of both dvf in this function
            batch_im, win_center, win_r_before, win_r_after, predicted_begin, predicted_end = pair_pyramid.next_sweep_patch(
            )
            time_before_gpu = time.time()
            [dvf_np] = sess.run([dvf_tf],
                                feed_dict={
                                    images_tf: batch_im,
                                    bn_training: 0
                                })
            time_after_gpu = time.time()
            logging.debug('GPU: ' + pair_info[0]['data'] +
                          ', CN{} center={} is done in {:.2f}s '.format(
                              pair_info[0]['cn'], win_center, time_after_gpu -
                              time_before_gpu))

            pyr['DVF_s'+str(stage)][win_center[0] - win_r_before[0]: win_center[0] + win_r_after[0],
                                    win_center[1] - win_r_before[1]: win_center[1] + win_r_after[1],
                                    win_center[2] - win_r_before[2]: win_center[2] + win_r_after[2], :] = \
                dvf_np[0, predicted_begin[0]:predicted_end[0], predicted_begin[1]:predicted_end[1], predicted_begin[2]:predicted_end[2], :]
            # rescaling dvf based on the voxel spacing:
            spacing_ref = [1.0 * stage for _ in range(3)]
            spacing_current = pyr['fixed_im_s' + str(stage) +
                                  '_sitk'].GetSpacing()
            for dim in range(3):
                pyr['DVF_s' + str(stage)][:, :, :, dim] = pyr['DVF_s' + str(
                    stage)][:, :, :,
                            dim] * spacing_current[dim] / spacing_ref[dim]

        pyr['DVF_s' + str(stage) + '_sitk'] = ip.array_to_sitk(
            pyr['DVF_s' + str(stage)],
            im_ref=pyr['fixed_im_s' + str(stage) + '_sitk'],
            is_vector=True)

        if i_stage == (len(setting['ImagePyramidSchedule']) - 1):
            # when all stages are finished, final dvf and moved image are written
            dvf_composed_final_sitk = 'DVF_s' + str(stage) + '_composed_sitk'
            if len(setting['ImagePyramidSchedule']) == 1:
                # need to upsample in the case that last stage is not 1
                if stage == 1:
                    pyr[dvf_composed_final_sitk] = pyr['DVF_s' + str(stage) +
                                                       '_sitk']
                else:
                    pyr[dvf_composed_final_sitk] = ip.resampler_sitk(
                        pyr['DVF_s' + str(stage) + '_sitk'],
                        scale=1 / stage,
                        im_ref_size=pyr['fixed_im_s1_sitk'].GetSize(),
                        interpolator=sitk.sitkLinear)
            else:
                pyr[dvf_composed_final_sitk] = sitk.Add(
                    pyr['DVF_s' + str(setting['ImagePyramidSchedule'][-2]) +
                        '_composed_up_sitk'],
                    pyr['DVF_s' + str(stage) + '_sitk'])
                if stage != 1:
                    pyr[dvf_composed_final_sitk] = ip.resampler_sitk(
                        pyr[dvf_composed_final_sitk],
                        scale=1 / stage,
                        im_ref_size=pyr['fixed_im_s1_sitk'].GetSize(),
                        interpolator=sitk.sitkLinear)
            if not setting['WriteNoDVF']:
                sitk.WriteImage(
                    sitk.Cast(pyr[dvf_composed_final_sitk],
                              sitk.sitkVectorFloat32),
                    su.address_generator(setting,
                                         'dvf_s0',
                                         pair_info=pair_info,
                                         stage_list=stage_list))
            dvf_t = sitk.DisplacementFieldTransform(
                pyr[dvf_composed_final_sitk])
            pyr['moved_im_s0_sitk'] = ip.resampler_by_transform(
                pyr['moving_im_s1_sitk'],
                dvf_t,
                default_pixel_value=setting['data'][
                    pair_info[1]['data']]['DefaultPixelValue'])
            sitk.WriteImage(
                sitk.Cast(pyr['moved_im_s0_sitk'],
                          setting['data'][pair_info[1]['data']]['ImageByte']),
                moved_im_s0_address)

            if setting['WriteMasksForLSTM']:
                mask_to_zero_stage = setting['network_dict'][
                    'stage' + str(stage)]['MaskToZero']
                if setting['read_pair_mode'] == 'synthetic':
                    moving_mask_sitk = sitk.ReadImage(
                        su.address_generator(setting,
                                             mask_to_zero_stage,
                                             stage=1,
                                             **im_info_su))
                    moved_mask_stage1 = ip.resampler_by_transform(
                        moving_mask_sitk,
                        dvf_t,
                        default_pixel_value=0,
                        interpolator=sitk.sitkNearestNeighbor)
                    sitk.WriteImage(
                        moved_mask_stage1,
                        su.address_generator(setting,
                                             'Moved' + mask_to_zero_stage +
                                             '_AG',
                                             stage=1,
                                             **im_info_su))
                    logging.debug('writing ' +
                                  su.address_generator(setting,
                                                       'Moved' +
                                                       mask_to_zero_stage +
                                                       '_AG',
                                                       stage=1,
                                                       **im_info_su))

    time_after_dvf = time.time()
    logging.debug(
        pair_info[0]['data'] + ', CN{}, ImType{} is done in {:.2f}s '.format(
            pair_info[0]['cn'], pair_info[0]['type_im'], time_after_dvf -
            time_before_dvf))

    return 0
예제 #23
0
def single_freq(setting,
                im_info,
                stage,
                im_input_sitk,
                gonna_generate_next_im=False):
    """
    Generate a synthetic single-frequency displacement vector field (DVF).

    A random B-spline transform is created by ``bspline_coeff`` with a random
    state seeded from ``im_info``. The transform is optionally written to
    disk (both as a transform and as a vector image of its coefficients) and
    then converted to a dense displacement field on the grid of
    ``im_input_sitk``.

    :param setting: global setting dictionary.
    :param im_info: image information dictionary. The keys 'data',
        'deform_exp', 'type_im', 'cn', 'dsmooth', 'padto' and
        'deform_number' are read.
    :param stage: pyramid stage. The grid smoothing sigmas are divided by it,
        and it selects the 'DVFPad_S<stage>' setting.
    :param im_input_sitk: SimpleITK image that defines the output grid.
    :param gonna_generate_next_im: controls which settings are used.
        If True, use 'NextIm_MaxDeform', element 0 of the SingleFrequency_*
        lists, and a seed shifted by one. If False, use element
        ``im_info['deform_number']`` of those lists scaled from 'MaxDeform'.
    :return: dvf, the numpy array returned by ``sitk.GetArrayFromImage`` on
        the displacement field. It is optionally masked and normalized.
    :raises ValueError: if DVF padding is requested for an image whose voxel
        spacing is not isotropic.
    """
    im_info_su = {
        'data': im_info['data'],
        'deform_exp': im_info['deform_exp'],
        'type_im': im_info['type_im'],
        'cn': im_info['cn'],
        'dsmooth': im_info['dsmooth'],
        'stage': stage,
        'padto': im_info['padto']
    }
    seed_number = ag_utils.seed_number_by_im_info(
        im_info,
        'single_freq',
        stage=stage,
        gonna_generate_next_im=gonna_generate_next_im)
    deform_number = im_info['deform_number']
    deform_exp_setting = setting['deform_exp'][im_info['deform_exp']]
    dim_im = 3  # the generated deformation is always 3D

    if gonna_generate_next_im:
        # The next image always uses the first frequency setting and a
        # different seed than the main deformation.
        freq_index = 0
        max_deform = deform_exp_setting['NextIm_MaxDeform']
        seed_number = seed_number + 1
        transform_name = 'NextBSplineTransform'
    else:
        freq_index = deform_number
        max_deform = deform_exp_setting['MaxDeform'] * \
            deform_exp_setting['SingleFrequency_MaxDeformRatio'][deform_number]
        transform_name = 'BSplineTransform'

    grid_border_to_zero = deform_exp_setting[
        'SingleFrequency_SetGridBorderToZero'][freq_index]
    grid_spacing = deform_exp_setting[
        'SingleFrequency_BSplineGridSpacing'][freq_index]
    grid_smoothing_sigma = [
        i / stage for i in
        deform_exp_setting['SingleFrequency_GridSmoothingSigma'][freq_index]
    ]
    bspline_transform_address = su.address_generator(
        setting, transform_name, **im_info_su)
    bspline_im_address = su.address_generator(setting,
                                              transform_name + 'Im',
                                              **im_info_su)
    random_state = np.random.RandomState(seed_number)

    if setting['DVFPad_S' + str(stage)] > 0:
        # im_input is already zero-padded in this case. The padding (in
        # voxels) is converted to mm using the first spacing component, which
        # is only valid for isotropic voxels.
        spacing = im_input_sitk.GetSpacing()
        if len(np.unique(spacing)) > 1:
            raise ValueError(
                'dvf_generation: padding is only implemented for isotropic voxel size. current voxel size = [{}, {}, {}]'
                .format(spacing[0], spacing[1], spacing[2]))
        padded_mm = setting['DVFPad_S' + str(stage)] * spacing[0]
        # Enlarge the zero border of the B-spline grid so that it also
        # covers the padded region. `np.int` was removed in NumPy 1.24;
        # the builtin `int` is the same type.
        grid_border_to_zero = (grid_border_to_zero + np.ceil(
            np.repeat(padded_mm, int(dim_im)) / grid_spacing)).astype(int)

    bcoeff = bspline_coeff(im_input_sitk,
                           max_deform,
                           grid_border_to_zero,
                           grid_smoothing_sigma,
                           grid_spacing,
                           random_state,
                           dim_im,
                           artificial_generation='single_frequency')

    if setting['WriteBSplineTransform']:
        sitk.WriteTransform(bcoeff, bspline_transform_address)
        bspline_im_sitk_tuple = bcoeff.GetCoefficientImages()
        # Stack the x, y and z coefficient images into one vector image.
        # (Previously the y image was written twice and z was lost.)
        bspline_im = np.stack([
            sitk.GetArrayFromImage(bspline_im_sitk_tuple[dim])
            for dim in range(3)
        ],
                              axis=-1)
        bspline_spacing = bspline_im_sitk_tuple[0].GetSpacing()
        bspline_origin = [
            list(bspline_im_sitk_tuple[0].GetOrigin())[i] +
            list(im_input_sitk.GetOrigin())[i] for i in range(3)
        ]
        bspline_direction = im_input_sitk.GetDirection()
        bspline_im_sitk = ip.array_to_sitk(bspline_im,
                                           origin=bspline_origin,
                                           spacing=bspline_spacing,
                                           direction=bspline_direction,
                                           is_vector=True)
        sitk.WriteImage(bspline_im_sitk, bspline_im_address)

    dvf_filter = sitk.TransformToDisplacementFieldFilter()
    dvf_filter.SetSize(im_input_sitk.GetSize())
    dvf_sitk = dvf_filter.Execute(bcoeff)
    dvf = sitk.GetArrayFromImage(dvf_sitk)

    mask_to_zero = deform_exp_setting['MaskToZero']
    if mask_to_zero is not None and not gonna_generate_next_im:
        sigma = deform_exp_setting[
            'SingleFrequency_BackgroundSmoothingSigma'][deform_number]
        dvf = do_mask_to_zero_gaussian(setting, im_info_su, dvf, mask_to_zero,
                                       stage, max_deform, sigma)

    if deform_exp_setting['DVFNormalization']:
        dvf = normalize_dvf(dvf, max_deform)
    return dvf
예제 #24
0
def check_downsampled_base_reg(setting,
                               stage,
                               base_reg=None,
                               pair_info=None,
                               mask_to_zero_stage=None):
    """
    Ensure the base-registration moved image and moved mask exist at `stage`.

    Two files are checked at the requested stage:
    'MovedImBaseReg' and 'Moved<mask_to_zero_stage>BaseReg'. If one is
    missing, its stage-1 version is read, downsampled by `stage` and written
    with the configured pixel type.

    :param setting: setting dictionary. If 'DownSamplingByGPU' is missing it
        is inserted with the value False.
    :param stage: downsampling factor of the requested pyramid stage.
    :param base_reg: base-registration name forwarded to
        su.address_generator.
    :param pair_info: (fixed, moving) image-info pair. Only the moving entry
        pair_info[1] is used.
    :param mask_to_zero_stage: mask name used to build the moved-mask file
        name (e.g. 'Torso' or 'Lung').
    :return: 0
    """
    setting.setdefault('DownSamplingByGPU', False)
    moving_info = pair_info[1]
    moving_data = setting['data'][moving_info['data']]
    downsample_jobs = (
        {
            'Image': 'MovedImBaseReg',
            'interpolator': 'BSpline',
            'DefaultPixelValue': moving_data['DefaultPixelValue'],
            'ImageByte': moving_data['ImageByte'],
        },
        {
            'Image': 'Moved' + mask_to_zero_stage + 'BaseReg',
            'interpolator': 'NearestNeighbor',
            'DefaultPixelValue': 0,
            'ImageByte': sitk.sitkInt8,
        },
    )

    # Keyword arguments shared by every address lookup below.
    address_kwargs = dict(pair_info=pair_info, base_reg=base_reg, **moving_info)
    sitk_interpolators = {
        'NearestNeighbor': sitk.sitkNearestNeighbor,
        'BSpline': sitk.sitkBSpline,
    }
    # These masks are resampled onto the grid of the downsampled moved image.
    masks_with_image_reference = ('MovedTorsoBaseReg', 'MovedLungBaseReg')

    for job in downsample_jobs:
        name = job['Image']
        target_address = su.address_generator(setting,
                                              name,
                                              stage=stage,
                                              **address_kwargs)
        if os.path.isfile(target_address):
            continue  # already available at this stage

        full_res_sitk = sitk.ReadImage(
            su.address_generator(setting, name, stage=1, **address_kwargs))

        if setting['DownSamplingByGPU'] and name == 'MovedImBaseReg':
            downsampled = ip.downsampler_gpu(
                sitk.GetArrayFromImage(full_res_sitk),
                stage,
                normalize_kernel=True,
                default_pixel_value=job['DefaultPixelValue'])
            scaled_spacing = tuple(s * stage
                                   for s in full_res_sitk.GetSpacing())
            downsampled_sitk = ip.array_to_sitk(
                downsampled,
                origin=full_res_sitk.GetOrigin(),
                spacing=scaled_spacing,
                direction=full_res_sitk.GetDirection())
        else:
            if job['interpolator'] not in sitk_interpolators:
                raise ValueError(
                    "interpolator should be in ['NearestNeighbor', 'BSpline']"
                )
            chosen_interpolator = sitk_interpolators[job['interpolator']]

            reference_sitk = None
            if name in masks_with_image_reference:
                reference_sitk = sitk.ReadImage(
                    su.address_generator(setting,
                                         'MovedImBaseReg',
                                         stage=stage,
                                         **address_kwargs))
            downsampled_sitk = ip.resampler_sitk(
                full_res_sitk,
                scale=stage,
                im_ref=reference_sitk,
                default_pixel_value=job['DefaultPixelValue'],
                interpolator=chosen_interpolator)

        sitk.WriteImage(sitk.Cast(downsampled_sitk, job['ImageByte']),
                        target_address)
    return 0
예제 #25
0
def multi_stage_error(setting, landmark_dict, pair_info, overwrite=False):
    """
    :param setting:
    :param pair_info: information of the pair to be registered.
    :param overwrite:
    :return: The output moved images and dvf will be written to the disk.
             1: registration is performed correctly
             2: skip overwriting
             3: the dvf is available from the previous experiment [4, 2, 1]. Then just upsample it.
    """
    stage_list = setting['ImagePyramidSchedule']
    if setting['CNN_Mode']:
        lstm_mode = False
    else:
        lstm_mode = True
    dvf_error_address = su.address_generator(
        setting,
        'dvf_error',
        current_experiment=setting['lstm_exp'],
        pair_info=pair_info,
        stage_list=stage_list,
        base_reg=setting['BaseReg'])
    if os.path.isfile(dvf_error_address):
        if not overwrite:
            logging.debug('overwrite=False, file ' + dvf_error_address +
                          ' already exists, skipping .....')
            return 2
        else:
            logging.debug('overwrite=True, file ' + dvf_error_address +
                          ' already exists, but overwriting .....')

    im_dict = dict()
    im_info_fixed = pair_info[0]
    im_info_moving = pair_info[1]
    base_reg = copy.copy(setting['BaseReg'])
    # im_info_moving['base_reg'] = copy.copy(base_reg)

    network_lstm_design = setting['network_lstm_dict']['NetworkDesign']
    setting['stage'] = 1
    setting = su.load_network_setting(setting,
                                      network_name=network_lstm_design)
    # setting['ImPad_S1'] = 50  # more safe selection
    if im_info_fixed['data'] == 'DIR-Lab_COPD':
        if setting['ImPad_S1'] < 42:
            setting['ImPad_S1'] = 42

    if im_info_fixed['data'] == 'DIR-Lab_4D':
        if setting['ImPad_S1'] < 40:
            setting['ImPad_S1'] = 40

    setting['ImPad_S2'] = setting['ImPad_S1'] + 7  # 5
    setting['ImPad_S4'] = setting['ImPad_S1'] + 10  # 8

    base_key_list = [
        'fixed_im_s', 'fixed_mask_s', 'moved_im_s', 'moved_mask_s'
    ]
    for i_stage, stage in enumerate(stage_list):
        mask_to_zero_stage = setting['network_dict']['stage' +
                                                     str(stage)]['MaskToZero']
        fixed_im_stage_address = su.address_generator(setting,
                                                      'Im',
                                                      stage=stage,
                                                      **im_info_fixed)
        fixed_mask_stage_address = su.address_generator(setting,
                                                        mask_to_zero_stage,
                                                        stage=stage,
                                                        **im_info_fixed)
        im_dict['fixed_im_s' + str(stage)] = sitk.GetArrayFromImage(
            sitk.ReadImage(fixed_im_stage_address))
        im_dict['fixed_mask_s' + str(stage)] = sitk.GetArrayFromImage(
            sitk.ReadImage(fixed_mask_stage_address))

        if i_stage == 0:
            check_downsampled_base_reg(setting,
                                       stage,
                                       base_reg=base_reg,
                                       pair_info=pair_info,
                                       mask_to_zero_stage=mask_to_zero_stage)
            moved_im_stage_address = su.address_generator(setting,
                                                          'MovedImBaseReg',
                                                          pair_info=pair_info,
                                                          stage=stage,
                                                          base_reg=base_reg,
                                                          **im_info_moving)
            moved_mask_stage_address = su.address_generator(
                setting,
                'Moved' + mask_to_zero_stage + 'BaseReg',
                pair_info=pair_info,
                stage=stage,
                base_reg=base_reg,
                **im_info_moving)
        else:
            if setting['UseRegisteredImages']:
                moved_im_stage_address = su.address_generator(
                    setting,
                    'MovedIm',
                    current_experiment=setting['exp_multi_reg'],
                    pair_info=pair_info,
                    stage=stage,
                    stage_list=stage_list,
                    base_reg=base_reg)
                moved_mask_stage_address = su.address_generator(
                    setting,
                    'Moved' + mask_to_zero_stage,
                    current_experiment=setting['exp_multi_reg'],
                    pair_info=pair_info,
                    stage=stage,
                    stage_list=stage_list,
                    base_reg=base_reg)

            else:
                check_downsampled_base_reg(
                    setting,
                    stage,
                    base_reg=base_reg,
                    pair_info=pair_info,
                    mask_to_zero_stage=mask_to_zero_stage)
                moved_im_stage_address = su.address_generator(
                    setting,
                    'MovedImBaseReg',
                    pair_info=pair_info,
                    stage=stage,
                    base_reg=base_reg,
                    **im_info_moving)
                moved_mask_stage_address = su.address_generator(
                    setting,
                    'Moved' + mask_to_zero_stage + 'BaseReg',
                    pair_info=pair_info,
                    stage=stage,
                    base_reg=base_reg,
                    **im_info_moving)

        im_dict['moved_im_s' + str(stage)] = sitk.GetArrayFromImage(
            sitk.ReadImage(moved_im_stage_address))
        im_dict['moved_mask_s' + str(stage)] = sitk.GetArrayFromImage(
            sitk.ReadImage(moved_mask_stage_address))
        logging.info(fixed_im_stage_address + ' loaded')
        logging.info(fixed_mask_stage_address + ' loaded')
        logging.info(moved_im_stage_address + ' loaded')
        logging.info(moved_mask_stage_address + ' loaded')

        default_pixel = setting['data'][pair_info[0]
                                        ['data']]['DefaultPixelValue']
        im_dict['fixed_im_s' +
                str(stage)][im_dict['fixed_mask_s' +
                                    str(stage)] == 0] = default_pixel
        im_dict['moved_im_s' +
                str(stage)][im_dict['moved_mask_s' +
                                    str(stage)] == 0] = default_pixel

        if setting['ImPad_S' + str(stage)] > 0:
            for base_key in base_key_list:
                im_dict[base_key + str(stage)] = np.pad(
                    im_dict[base_key + str(stage)],
                    setting['ImPad_S' + str(stage)],
                    'constant',
                    constant_values=(default_pixel, ))

    indices_padded = copy.deepcopy(landmark_dict['FixedLandmarksIndex'])
    indices_padded = indices_padded + setting['ImPad_S1']

    batch_size = 15
    for i in range(3):
        landmark_dict['DVF_error_times' + str(i) + '_logits'] = np.zeros(
            (np.shape(landmark_dict['FixedLandmarksIndex'])[0], ) +
            (setting['NumberOfLabels'], ))
        landmark_dict['DVF_error_times' + str(i) + '_label'] = np.zeros(
            np.shape(landmark_dict['FixedLandmarksIndex'])[0])
    multi_stage_network_design = load_network_multi_stage_from_predefined(
        setting['exp_multi_reg'])
    multi_stage_network_load = get_parameter_multi_stage_network(
        setting, multi_stage_network_design)
    multi_stage_network_address = dict()
    for stage in setting['ImagePyramidSchedule']:
        multi_stage_network_address[
            'stage' + str(stage)] = su.address_generator(
                setting,
                'saved_model_with_step',
                current_experiment=multi_stage_network_load[
                    'stage' + str(stage)]['NetworkLoad'],
                step=multi_stage_network_load['stage' +
                                              str(stage)]['GlobalStepLoad'])
    # Network
    tf.reset_default_graph()
    tf.set_random_seed(0)
    with tf.variable_scope('InputImages'):
        images_s1_tf = tf.placeholder(tf.float32,
                                      shape=[
                                          batch_size, 2 * setting['R'] + 1,
                                          2 * setting['R'] + 1,
                                          2 * setting['R'] + 1, 2
                                      ],
                                      name="Images_S1")
        images_s2_tf = tf.placeholder(tf.float32,
                                      shape=[
                                          batch_size, 2 * setting['R'] + 1,
                                          2 * setting['R'] + 1,
                                          2 * setting['R'] + 1, 2
                                      ],
                                      name="Images_S2")
        images_s4_tf = tf.placeholder(tf.float32,
                                      shape=[
                                          batch_size, 2 * setting['R'] + 1,
                                          2 * setting['R'] + 1,
                                          2 * setting['R'] + 1, 2
                                      ],
                                      name="Images_S4")
    bn_training = tf.placeholder(tf.bool, name='bn_training')

    if lstm_mode:
        out0_tf, out1_tf, out2_tf, state0_tf, state1_tf = getattr(
            getattr(network, network_lstm_design),
            'network')(images_s1_tf,
                       images_s2_tf,
                       images_s4_tf,
                       bn_training,
                       detailed_summary=False,
                       use_keras=setting['use_keras'],
                       num_of_classes=setting['NumberOfLabels'],
                       multi_stage_network_address=multi_stage_network_address)
    else:
        out2_tf = getattr(getattr(network, network_lstm_design), 'network')(
            images_s1_tf,
            images_s2_tf,
            images_s4_tf,
            bn_training,
            detailed_summary=setting['DetailedNetworkSummary'],
            use_keras=setting['use_keras'],
            num_of_classes=setting['NumberOfLabels'],
            multi_stage_network_address=multi_stage_network_address)
        out0_tf, out1_tf, state0_tf, state1_tf = None, None, None, None

    sess = tf.Session()

    saver_loading = tf.train.Saver(
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
    saver_loading.restore(
        sess,
        su.address_generator(
            setting,
            'saved_model_with_step',
            current_experiment=setting['network_lstm_dict']['NetworkLoad'],
            step=setting['network_lstm_dict']['GlobalStepLoad']))

    time_before_dvf = time.time()
    begin_landmark_i = 0
    end_landmark_i = 0
    fill_last_batch = 0
    while end_landmark_i < np.shape(
            landmark_dict['DVF_error_times2_logits'])[0]:
        end_landmark_i = begin_landmark_i + batch_size
        if end_landmark_i > np.shape(
                landmark_dict['DVF_error_times2_logits'])[0]:
            fill_last_batch = end_landmark_i - np.shape(
                landmark_dict['DVF_error_times2_logits'])[0]
            end_landmark_i = np.shape(
                landmark_dict['DVF_error_times2_logits'])[0]
        logging.info('begin_landmark_i={}, end_landmark_i={}'.format(
            begin_landmark_i, end_landmark_i))
        batch_im = extract_batch_from_patch_seq(
            setting, im_dict,
            indices_padded[begin_landmark_i:end_landmark_i, :])
        if fill_last_batch:
            batch_im = fill_junk_to_batch_size(setting, batch_size, batch_im)
        out_np = [None for _ in range(3)]
        if lstm_mode:
            [out_np[0], out_np[1], out_np[2]] = sess.run(
                [out0_tf, out1_tf, out2_tf],
                feed_dict={
                    images_s1_tf: batch_im['stage1'],
                    images_s2_tf: batch_im['stage2'],
                    images_s4_tf: batch_im['stage4'],
                    bn_training: 0
                })
        else:
            [out_np[2]] = sess.run(
                [out2_tf],
                feed_dict={
                    images_s1_tf: batch_im['stage1'],
                    images_s2_tf: batch_im['stage2'],
                    images_s4_tf: batch_im['stage4'],
                    bn_training: 0
                })

        for i in range(3):
            if lstm_mode or i == 2:
                out_np_center = out_np[i][:, setting['Ry'] + 1,
                                          setting['Ry'] + 1,
                                          setting['Ry'] + 1, :]
                out_np_center_label = np.argmax(
                    out_np_center[:, setting['Labels_time' + str(i)]], axis=1)
                landmark_dict['DVF_error_times' + str(i) + '_label'][
                    begin_landmark_i:end_landmark_i] = out_np_center_label[0:(
                        batch_size - fill_last_batch)]
                landmark_dict['DVF_error_times' + str(i) + '_logits'][
                    begin_landmark_i:end_landmark_i, :] = out_np_center[0:(
                        batch_size - fill_last_batch), :]
        begin_landmark_i += 15

    time_after_dvf = time.time()
    logging.debug(
        pair_info[0]['data'] + ', CN{}, ImType{} is done in {:.2f}s '.format(
            pair_info[0]['cn'], pair_info[0]['type_im'], time_after_dvf -
            time_before_dvf))

    landmark_dict['DVF_nonrigidGroundTruth_magnitude'] = np.sqrt(
        landmark_dict['DVF_nonrigidGroundTruth'][:, 0]**2 +
        landmark_dict['DVF_nonrigidGroundTruth'][:, 1]**2 +
        landmark_dict['DVF_nonrigidGroundTruth'][:, 2]**2)

    result_landmarks_folder = su.address_generator(
        setting,
        'result_landmarks_folder',
        stage_list=stage_list,
        current_experiment=setting['lstm_exp'])
    if not os.path.isdir(result_landmarks_folder):
        os.makedirs(result_landmarks_folder)

    pair_info_text = setting['BaseReg'] + '_' +  pair_info[0]['data'] + \
                     '_CN{}_TypeIm{},'.format(pair_info[0]['cn'], pair_info[0]['type_im']) + '  Moving:' + \
                     pair_info[1]['data'] + '_CN{}_TypeIm{}'.format(pair_info[1]['cn'], pair_info[1]['type_im'])

    # import matplotlib.pyplot as plt
    # plt.figure()
    # plt.plot(landmark_dict['DVF_error_times2_label'], landmark_dict['DVF_nonrigidGroundTruth_magnitude'], 'o')
    # plt.draw()
    # plt.savefig(result_landmarks_folder+pair_info_text+'.png')
    # plt.close()

    return landmark_dict
예제 #26
0
    def fill(self):
        """
        Load the next chunk of images/DVFs together with its shuffled sample indices.

        Steps, in order:
        1. In 'Training' mode, reset the fixed/moved image lists and the DVF list.
        2. If the previous call finished a semi-epoch, advance ``self._semi_epoch``
           and restart at chunk 0.
        3. Select the images of chunk ``self._chunk`` from a permutation seeded by
           the semi-epoch number (original order if ``setting['Randomness']`` is
           falsy). The last chunk may be shorter; it flags the semi-epoch as
           completed and resizes the lists.
        4. Load fixed/moved images, DVFs and torso masks with
           ``ag.get_dvf_and_deformed_images_seq``.
        5. Load the IShuffled index array from disk, or generate and save it.
        6. In 'Thread' mode: log a per-class histogram (patch mode only), set
           ``self._filled = 1`` and call ``self.pause()``.

        ``self._chunk`` is not advanced here. NOTE(review): presumably the caller
        increments it; confirm.
        """
        self._filled = 0
        number_of_images_per_chunk = self._number_of_images_per_chunk
        if self._train_mode == 'Training':
            # Make all lists empty in training mode. In the validation or test mode, we keep the same chunk forever. So no need to make it empty and refill it again.
            self._fixed_im_list = self.empty_sequence(
                number_of_images_per_chunk)
            self._moved_im_list = self.empty_sequence(
                number_of_images_per_chunk)
            self._dvf_list = [None] * number_of_images_per_chunk

        # The previous call consumed the last chunk: start a new semi-epoch.
        if self._semi_epochs_completed:
            self._semi_epoch = self._semi_epoch + 1
            self._semi_epochs_completed = 0
            self._chunk = 0

        im_info_list_full = copy.deepcopy(self._im_info_list_full)
        # Seeding with the semi-epoch number yields the same permutation for every
        # chunk of a semi-epoch, so the chunks partition the image list.
        random_state = np.random.RandomState(self._semi_epoch)
        if self._setting['Randomness']:
            random_indices = random_state.permutation(len(im_info_list_full))
        else:
            random_indices = np.arange(len(im_info_list_full))

        lower_range = self._chunk * number_of_images_per_chunk
        upper_range = (self._chunk + 1) * number_of_images_per_chunk
        if upper_range >= len(im_info_list_full):
            # Last chunk of the semi-epoch: clamp it and resize the lists to match.
            upper_range = len(im_info_list_full)
            self._semi_epochs_completed = 1
            number_of_images_per_chunk = upper_range - lower_range  # The last chunk can hold fewer images than self._number_of_images_per_chunk
            self._fixed_im_list = self.empty_sequence(
                number_of_images_per_chunk)
            self._moved_im_list = self.empty_sequence(
                number_of_images_per_chunk)
            self._dvf_list = [None] * number_of_images_per_chunk

        log_msg = self._class_mode + ': stage={}, SemiEpoch={}, Chunk={} '.format(
            self._stage, self._semi_epoch, self._chunk)
        logging.debug(log_msg)
        if self._class_mode == 'Thread':
            with open(su.address_generator(self._setting, 'log_im_file'),
                      'a+') as f:
                f.write(log_msg + '\n')

        torso_list = [None] * len(self._dvf_list)
        indices_chunk = random_indices[lower_range:upper_range]
        im_info_list = [im_info_list_full[i] for i in indices_chunk]
        for i_index_im, index_im in enumerate(indices_chunk):
            # Load (or generate) fixed image, moved image, DVF and torso mask for one image.
            self._fixed_im_list[i_index_im], self._moved_im_list[i_index_im], self._dvf_list[i_index_im], torso_list[i_index_im] = \
                ag.get_dvf_and_deformed_images_seq(self._setting,
                                                   im_info=im_info_list_full[index_im],
                                                   stage_sequence=self._stage_sequence,
                                                   mode_synthetic_dvf=self._mode_artificial_generation
                                                   )
            # In '1stEpoch' mode only the DVFs/torso masks are kept; the images are dropped.
            if self._class_mode == '1stEpoch':
                self._fixed_im_list[i_index_im] = None
                self._moved_im_list[i_index_im] = None
            if self._setting['verbose']:
                log_msg = self._class_mode+': Data='+im_info_list_full[index_im]['data'] +\
                          ', TypeIm={}, CN={}, Dsmooth={}, stage={} is loaded'.format(im_info_list_full[index_im]['type_im'], im_info_list_full[index_im]['cn'],
                                                                                      im_info_list_full[index_im]['dsmooth'], self._stage)
                logging.debug(log_msg)
                if self._class_mode == 'Thread':
                    with open(
                            su.address_generator(self._setting, 'log_im_file'),
                            'a+') as f:
                        f.write(log_msg + '\n')

        ishuffled_folder = reading_utils.get_ishuffled_folder_write_ishuffled_setting(
            self._setting,
            self._train_mode,
            self._stage,
            self._number_of_images_per_chunk,
            self._samples_per_image,
            self._im_info_list_full,
            full_image=self._full_image,
            chunk_length_force_to_multiple_of=self.
            _chunk_length_force_to_multiple_of)
        ishuffled_name = su.address_generator(self._setting,
                                              'IShuffledName',
                                              semi_epoch=self._semi_epoch,
                                              chunk=self._chunk)
        ishuffled_address = ishuffled_folder + ishuffled_name

        # In 'reading' mode, block until the IShuffled file exists (presumably
        # written by another worker). Polls every 5 s with no timeout.
        if self._mode_artificial_generation == 'reading' and not (
                self._class_mode == 'Thread' and self._semi_epoch > 0
        ) and not self._setting['never_generate_image']:
            count_wait = 1
            while not os.path.isfile(ishuffled_address):
                time.sleep(5)
                logging.debug(
                    self._class_mode +
                    ': waiting {} s for IShuffled:'.format(count_wait * 5) +
                    ishuffled_address)
                count_wait += 1
            self._ishuffled = np.load(ishuffled_address)

        # Reuse a saved IShuffled array if present, otherwise generate and save it.
        # NOTE(review): when the wait loop above ran, the file is loaded a second time here.
        if os.path.isfile(ishuffled_address):
            self._ishuffled = np.load(ishuffled_address)
            log_msg = self._class_mode + ': loading IShuffled: ' + ishuffled_address
        else:
            log_msg = self._class_mode + ': generating IShuffled: ' + ishuffled_address
            self._ishuffled = reading_utils.shuffled_indices_from_chunk(
                self._setting,
                dvf_list=self._dvf_list,
                torso_list=torso_list,
                im_info_list=im_info_list,
                stage_sequence=self._stage_sequence,
                semi_epoch=self._semi_epoch,
                chunk=self._chunk,
                samples_per_image=self._samples_per_image,
                log_header=self._class_mode,
                full_image=self._full_image,
                seq_mode=True,
                chunk_length_force_to_multiple_of=self.
                _chunk_length_force_to_multiple_of)
            np.save(ishuffled_address, self._ishuffled)
            logging.debug(self._class_mode + ': saving IShuffled: ' +
                          ishuffled_address)

        logging.debug(log_msg)
        if self._class_mode == 'Thread':
            with open(su.address_generator(self._setting, 'log_im_file'),
                      'a+') as f:
                f.write(log_msg + '\n')
        if self._class_mode == 'Thread':
            if not self._full_image:
                # Count samples per class. Column 2 of IShuffled is compared against
                # class indices, i.e. presumably it holds the class label.
                class_balanced = self._setting['ClassBalanced']
                hist_class = np.zeros(len(class_balanced), dtype=np.int32)
                hist_text = ''
                for c in range(len(class_balanced)):
                    hist_class[c] = sum(self._ishuffled[:, 2] == c)
                    hist_text = hist_text + 'Class' + str(c) + ': ' + str(
                        hist_class[c]) + ', '
                log_msg = hist_text + self._class_mode + ': stage={}, SemiEpoch={}, Chunk={} '.format(
                    self._stage, self._semi_epoch, self._chunk)
                with open(su.address_generator(self._setting, 'log_im_file'),
                          'a+') as f:
                    f.write(log_msg + '\n')

            with open(su.address_generator(self._setting, 'log_im_file'),
                      'a+') as f:
                f.write('========================' + '\n')
            logging.debug('Thread is filled .....will be paused')
            # Mark the chunk as ready and pause this thread.
            self._filled = 1
            self.pause()
예제 #27
0
def _clip_box_to_extent(starts, sizes, shape):
    """
    Intersect the box ``[starts, starts + sizes)`` with the array extent ``[0, shape)``.

    Boxes near the image border are clipped instead of wrapping around (negative
    slice starts) or failing to broadcast against a truncated slice.

    :param starts: per-axis start index of the box (may be negative)
    :param sizes: per-axis size of the box
    :param shape: shape of the target array
    :return: ``(dst, src)`` tuples of slices; ``dst`` selects the clipped box in the
             target array, ``src`` the matching part of a patch of shape ``sizes``
    """
    dst = []
    src = []
    for start, size, extent in zip(starts, sizes, shape):
        start = int(start)
        low = min(max(start, 0), int(extent))
        high = max(min(start + int(size), int(extent)), low)
        dst.append(slice(low, high))
        src.append(slice(low - start, high - start))
    return tuple(dst), tuple(src)


def add_occlusion(setting,
                  im_info,
                  stage,
                  deformed_im_previous_sitk=None,
                  dvf_sitk=None):
    """
    Paint random ellipsoidal occlusions inside the deformed lung of an image.

    Steps:
    1. Warp the lung mask with ``dvf_sitk`` and cache it as 'DeformedLung'.
    2. Erode the mask (9x9x9 cube) to get the pool of candidate centres.
    3. Draw ``Occlusion_NumberOfEllipse`` ellipsoids. Each has semi-axes of at
       least 3 voxels and at most ``Occlusion_Max_a/b/c``. After each one, its
       box plus a 5-voxel margin is removed from the candidate pool.
    4. Write the union of ellipsoids as 'DeformedLungOccluded'.
    5. Inside the union, set the image to a random occlusion intensity plus
       Gaussian noise, with a 3-voxel rim blended towards the original values.

    All randomness comes from a RandomState seeded by ``seed_number_by_im_info``.
    File addresses always use stage 1 (``im_info_su['stage'] = 1``); ``stage``
    only influences the seed.

    :param setting: global setting dict
    :param im_info: dict with 'data', 'deform_exp', 'type_im', 'cn', 'dsmooth', 'padto'
    :param stage: stage number used to derive the random seed
    :param deformed_im_previous_sitk: image to occlude; read from the 'DeformedIm'
        ('Noise') file when None
    :param dvf_sitk: displacement field used to warp the lung mask; read from
        'DeformedDVF' when None
    :return: SimpleITK image with the occlusions applied
    """
    import logging

    im_info_su = {
        'data': im_info['data'],
        'deform_exp': im_info['deform_exp'],
        'type_im': im_info['type_im'],
        'cn': im_info['cn'],
        'dsmooth': im_info['dsmooth'],
        'stage': 1,
        'padto': im_info['padto']
    }
    seed_number = ag_utils.seed_number_by_im_info(im_info,
                                                  'add_occlusion',
                                                  stage=stage)
    random_state = np.random.RandomState(seed_number)
    exp_setting = setting['deform_exp'][im_info['deform_exp']]

    if deformed_im_previous_sitk is None:
        deformed_im_previous_sitk = sitk.ReadImage(
            su.address_generator(setting,
                                 'DeformedIm',
                                 deformed_im_ext='Noise',
                                 **im_info_su))

    if dvf_sitk is None:
        dvf_address = su.address_generator(setting, 'DeformedDVF',
                                           **im_info_su)
        dvf_sitk = sitk.ReadImage(dvf_address)

    # Warp the lung mask with the DVF once and cache it on disk.
    deformed_lung_address = su.address_generator(setting, 'DeformedLung',
                                                 **im_info_su)
    if not os.path.isfile(deformed_lung_address):
        im_lung_sitk = sitk.ReadImage(
            su.address_generator(setting, 'Lung', **im_info_su))
        dvf_t = sitk.DisplacementFieldTransform(
            sitk.Cast(dvf_sitk, sitk.sitkVectorFloat64))
        deformed_lung_sitk = ip.resampler_by_transform(
            im_lung_sitk,
            dvf_t,
            im_ref=deformed_im_previous_sitk,
            default_pixel_value=0,
            interpolator=sitk.sitkNearestNeighbor)
        sitk.WriteImage(sitk.Cast(deformed_lung_sitk, sitk.sitkInt8),
                        deformed_lung_address)
        time.sleep(5)

    deformed_im_noise = sitk.GetArrayFromImage(deformed_im_previous_sitk)
    deformed_lung = sitk.GetArrayFromImage(
        sitk.ReadImage(deformed_lung_address))
    # `np.bool` was removed in NumPy 1.24; the builtin `bool` dtype is used instead.
    deformed_lung_erode = ndimage.binary_erosion(
        deformed_lung, structure=np.ones((9, 9, 9), dtype=bool)).astype(bool)
    ellipse_lung = np.zeros(deformed_im_noise.shape, dtype=bool)
    ellipse_center_lung = deformed_lung_erode.copy()
    margin = 5  # voxels around each placed box that may no longer be a centre

    for _ in range(exp_setting['Occlusion_NumberOfEllipse']):
        center_list = np.where(ellipse_center_lung > 0)
        n_candidates = len(center_list[0])
        if n_candidates == 0:
            # randint(0, -1) used to raise here; there is simply no room left.
            logging.warning(
                'add_occlusion: no candidate centre left, fewer ellipses placed than requested')
            break
        if n_candidates == 1:
            selected_center_i = 0
        else:
            # NOTE: `high` is exclusive, so the last candidate is never drawn. Kept
            # unchanged so that a given seed reproduces previously generated data.
            selected_center_i = int(
                random_state.randint(0, n_candidates - 1, 1,
                                     dtype=np.int64)[0])

        # Random semi-axes (drawn in the order a, b, c), at least 3 voxels each.
        semi_axes = [
            max(random_state.random_sample() * exp_setting[key], 3)
            for key in ('Occlusion_Max_a', 'Occlusion_Max_b', 'Occlusion_Max_c')
        ]
        a, b, c = semi_axes
        crop_shape = [2 * round(s) + 1 for s in semi_axes]
        i1, i2, i3 = np.ogrid[0:crop_shape[0], 0:crop_shape[1],
                              0:crop_shape[2]]
        # Voxels inside the ellipsoid with semi-axes (a, b, c) centred at (a, b, c).
        ellipse_crop = (((i1 - a) / a)**2 + ((i2 - b) / b)**2 +
                        ((i3 - c) / c)**2) < 1

        centre = [center_list[d][selected_center_i] for d in range(3)]
        # NOTE(review): the crop is anchored at centre - round(axis / 2), so the
        # ellipse centre lands about axis / 2 voxels past the selected voxel.
        starts = [centre[d] - round(semi_axes[d] / 2) for d in range(3)]
        dst, src = _clip_box_to_extent(starts, ellipse_crop.shape,
                                       ellipse_lung.shape)
        # Union with earlier ellipses; plain assignment used to erase parts of
        # earlier ellipses that fell inside this box.
        ellipse_lung[dst] |= ellipse_crop[src]

        guard_dst, _ = _clip_box_to_extent(
            [s - margin for s in starts],
            [n + 2 * margin for n in ellipse_crop.shape],
            ellipse_center_lung.shape)
        ellipse_center_lung[guard_dst] = False

    sitk.WriteImage(
        sitk.Cast(
            ip.array_to_sitk(ellipse_lung.astype(np.int8),
                             im_ref=deformed_im_previous_sitk), sitk.sitkInt8),
        su.address_generator(setting, 'DeformedLungOccluded', **im_info_su))

    # TODO: the occlusion values are written into an array of the input's pixel
    # type, so they may be truncated when that type is an integer type.
    print(
        '--------------------------------------will be corrected: occlusion intensity is not always an integer'
    )
    occlusion_intensity = int(
        random_state.randint(exp_setting['Occlusion_IntensityRange'][0],
                             exp_setting['Occlusion_IntensityRange'][1],
                             1,
                             dtype=np.int64)[0])

    # Blend a 3-voxel rim (successive 6-connected erosions) from image to occlusion.
    struct = np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]],
                       [[0, 1, 0], [1, 1, 1], [0, 1, 0]],
                       [[0, 0, 0], [0, 1, 0], [0, 0, 0]]])
    weight_occluded_list = np.array([0.2, 0.6, 0.9])
    weight_image_list = 1 - weight_occluded_list
    ellipse_lung_erode = ellipse_lung.copy()
    for weight_image, weight_occluded in zip(weight_image_list,
                                             weight_occluded_list):
        ellipse_lung_erode_new = ndimage.binary_erosion(
            ellipse_lung_erode, structure=struct).astype(bool)
        i_edge = np.where(ellipse_lung_erode ^ ellipse_lung_erode_new)
        deformed_im_noise[i_edge] = (deformed_im_noise[i_edge] * weight_image +
                                     occlusion_intensity * weight_occluded)
        ellipse_lung_erode = ellipse_lung_erode_new

    # Interior of the occlusions: constant intensity plus Gaussian noise.
    i_inside = np.where(ellipse_lung_erode)
    deformed_im_noise[i_inside] = occlusion_intensity + random_state.normal(
        scale=exp_setting['Im_NoiseSigma'], size=len(i_inside[0]))
    return ip.array_to_sitk(deformed_im_noise,
                            im_ref=deformed_im_previous_sitk)
예제 #28
0
        self.ind = self.slices // 2

        self.im = ax.imshow(self.X[self.ind, :, :], cmap=cmap, aspect=aspect)
        self.update()

    def onscroll(self, event):
        """Move one slice forward on scroll-up, one slice back otherwise (wrapping), then redraw."""
        print(event.button, event.step)
        delta = 1 if event.button == 'up' else -1
        self.ind = (self.ind + delta) % self.slices
        self.update()

    def update(self):
        """Display slice ``self.ind`` of ``self.X``, put its index on the y-axis and redraw."""
        current_slice = self.X[self.ind, :, :]
        self.im.set_data(current_slice)
        self.ax.set_ylabel('slice %s' % self.ind)
        canvas = self.im.axes.figure.canvas
        canvas.draw()


if __name__ == '__main__':
    # Manual check: load image type_im=0, cn=1 of the '3D_max15_D9' setting and
    # browse it with the scrollable 3D viewer.
    current_experiment = ''
    setting = su.initialize_setting(current_experiment)
    setting['deformName'] = '3D_max15_D9'
    setting = su.load_setting_from_data_dict(setting)
    im_sitk = sitk.ReadImage(
        su.address_generator(setting, 'Im', type_im=0, cn=1))
    im = sitk.GetArrayFromImage(im_sitk)
    view3d_image(im, slice_axis=1)
예제 #29
0
def add_noise(setting,
              im_info,
              stage,
              deformed_im_previous_sitk=None,
              deformed_torso_sitk=None,
              gonna_generate_next_im=False):
    """
    Add Gaussian noise with a random mean to a deformed (or "next") image.

    The noise sigma is 'NextIm_SigmaN' when ``gonna_generate_next_im`` is set and
    'Im_NoiseSigma' otherwise. The mean is drawn uniformly from
    [-Im_NoiseAverage, Im_NoiseAverage] and truncated to int when the dataset's
    'ImageByte' is an integer SimpleITK pixel type. With
    ``setting['UseTorsoMask']``, voxels outside the torso mask keep their
    original value.

    All randomness, including the seed handed to SimpleITK, comes from a
    RandomState seeded by ``seed_number_by_im_info``, so results are reproducible.

    :param setting: global setting dict
    :param im_info: dict with 'data', 'deform_exp', 'type_im', 'cn', 'dsmooth', 'padto'
    :param stage: stage number (used for file addresses and the seed)
    :param deformed_im_previous_sitk: input image; read from the 'DeformedIm'
        ('Clean') file when None
    :param deformed_torso_sitk: torso mask; read from disk when None and needed
    :param gonna_generate_next_im: select the 'next image' sigma and torso file
    :return: noisy SimpleITK image
    """
    im_info_su = {
        'data': im_info['data'],
        'deform_exp': im_info['deform_exp'],
        'type_im': im_info['type_im'],
        'cn': im_info['cn'],
        'dsmooth': im_info['dsmooth'],
        'stage': stage,
        'padto': im_info['padto']
    }
    seed_number = ag_utils.seed_number_by_im_info(
        im_info,
        'add_noise',
        stage=stage,
        gonna_generate_next_im=gonna_generate_next_im)
    random_state = np.random.RandomState(seed_number)
    exp_setting = setting['deform_exp'][im_info['deform_exp']]

    if gonna_generate_next_im:
        sigma_noise = exp_setting['NextIm_SigmaN']
        torso_address = su.address_generator(setting, 'NextTorso',
                                             **im_info_su)
    else:
        sigma_noise = exp_setting['Im_NoiseSigma']
        torso_address = su.address_generator(setting, 'DeformedTorso',
                                             **im_info_su)

    if deformed_im_previous_sitk is None:
        deformed_im_previous_sitk = sitk.ReadImage(
            su.address_generator(setting,
                                 'DeformedIm',
                                 deformed_im_ext='Clean',
                                 **im_info_su))

    max_mean_noise = exp_setting['Im_NoiseAverage']
    random_mean = random_state.uniform(-max_mean_noise, max_mean_noise)
    integer_pixel_types = (sitk.sitkUInt8, sitk.sitkUInt16, sitk.sitkUInt32,
                           sitk.sitkUInt64, sitk.sitkInt8, sitk.sitkInt16,
                           sitk.sitkInt32, sitk.sitkInt64)
    if setting['data'][im_info['data']]['ImageByte'] in integer_pixel_types:
        random_mean = int(random_mean)

    # SimpleITK treats seed 0 as sitkWallClock (a time-based seed), so the old
    # hard-coded 0 made the noise differ on every call. Draw a fixed, non-zero
    # seed from the seeded random_state instead.
    noise_seed = int(random_state.randint(1, 2**31 - 1))
    deformed_im_noise_sitk = sitk.AdditiveGaussianNoise(
        deformed_im_previous_sitk, sigma_noise, random_mean, noise_seed)
    if setting['UseTorsoMask']:
        # no noise outside of Torso region.
        if deformed_torso_sitk is None:
            deformed_torso_sitk = sitk.ReadImage(torso_address)
        deformed_torso = sitk.GetArrayFromImage(deformed_torso_sitk)
        deformed_im_noise = sitk.GetArrayFromImage(deformed_im_noise_sitk)
        deformed_im_previous = sitk.GetArrayFromImage(
            deformed_im_previous_sitk)
        outside_torso = deformed_torso == 0
        deformed_im_noise[outside_torso] = deformed_im_previous[outside_torso]
        deformed_im_noise_sitk = ip.array_to_sitk(
            deformed_im_noise, im_ref=deformed_im_previous_sitk)
    return deformed_im_noise_sitk
예제 #30
0
def mixed_freq(setting, im_info, stage):
    """
    Generate a synthetic "mixed frequency" DVF for one image.

    Each of the three DVF components starts from a binary mask: Canny edges of
    the image inside the dilated lung mask. The masks are then reshaped
    iteratively (``MixedFrequency_Np`` iterations). Each iteration picks one
    random block per component and applies one of these phases:

    - iteration 0: dilate all masks with a 3x3x3 cube;
    - until ~10/12 of the iterations: clear the selected blocks (this keeps
      zero-deformation regions in the training set);
    - until ~11/12: dilate within the selected blocks (connectivity 2);
    - afterwards: dilate within the blocks with a one-sided 9x9x9 structure.

    Blocks flagged as finished are removed from the region available for
    selection. The masks multiply a random B-spline displacement field. The
    result is optionally padded ('DVFPad_S<stage>'), smoothed with a random
    Gaussian sigma, and optionally normalised to the maximum deformation.

    :param setting: global setting dict
    :param im_info: dict with 'data', 'deform_exp', 'type_im', 'cn', 'dsmooth',
        'padto' and 'deform_number'
    :param stage: resolution stage; sigmas and the block radius are divided by it
    :return: numpy DVF array whose last axis holds the 3 components
    """
    im_info_su = {
        'data': im_info['data'],
        'deform_exp': im_info['deform_exp'],
        'type_im': im_info['type_im'],
        'cn': im_info['cn'],
        'dsmooth': im_info['dsmooth'],
        'stage': stage,
        'padto': im_info['padto']
    }
    seed_number = ag_utils.seed_number_by_im_info(im_info,
                                                  'mixed_freq',
                                                  stage=stage)
    random_state = np.random.RandomState(seed_number)
    exp_setting = setting['deform_exp'][im_info['deform_exp']]
    deform_number = im_info['deform_number']
    max_deform = exp_setting['MaxDeform'] * \
        exp_setting['MixedFrequency_MaxDeformRatio'][deform_number]
    grid_smoothing_sigma = [
        i / stage
        for i in exp_setting['MixedFrequency_GridSmoothingSigma'][deform_number]
    ]
    grid_border_to_zero = exp_setting['MixedFrequency_SetGridBorderToZero'][
        deform_number]
    grid_spacing = exp_setting['MixedFrequency_BSplineGridSpacing'][
        deform_number]  # Approximately
    number_dilation = exp_setting['MixedFrequency_Np'][deform_number]

    im_canny_address = su.address_generator(setting, 'ImCanny', **im_info_su)
    im_sitk = sitk.ReadImage(su.address_generator(setting, 'Im', **im_info_su))
    if os.path.isfile(im_canny_address):
        im_canny_sitk = sitk.ReadImage(im_canny_address)
    else:
        im_canny_sitk = sitk.CannyEdgeDetection(
            sitk.Cast(im_sitk, sitk.sitkFloat32),
            lowerThreshold=exp_setting['Canny_LowerThreshold'],
            upperThreshold=exp_setting['Canny_UpperThreshold'])
        sitk.WriteImage(sitk.Cast(im_canny_sitk, sitk.sitkInt8),
                        im_canny_address)
    # `np.bool` was removed in NumPy 1.24; the builtin `bool` dtype is equivalent.
    lung_im = sitk.GetArrayFromImage(
        sitk.ReadImage(su.address_generator(setting, 'Lung',
                                            **im_info_su))).astype(bool)
    im_canny = sitk.GetArrayFromImage(im_canny_sitk)
    # Morphology is done with ndimage rather than SimpleITK for speed.
    lung_dilated = ndimage.binary_dilation(lung_im)
    available_region = np.logical_and(lung_dilated, im_canny)
    # One copy of the region per DVF component (last axis).
    available_region = np.tile(np.expand_dims(available_region, axis=-1), 3)
    dilated_edge = np.copy(available_region)

    itr_edge = 0
    i_edge = [None] * 3
    select_voxel = [None] * 3
    block_low = [None] * 3
    block_high = [None] * 3
    for dim in range(3):
        # NOTE(review): an older comment said "all voxels" are used now, but
        # available_region is still restricted to Canny edges inside the dilated lung.
        i_edge[dim] = np.where(available_region[:, :, :, dim] > 0)
    if (len(i_edge[0][0]) == 0) or (len(i_edge[1][0]) == 0) or (len(
            i_edge[2][0]) == 0):
        logging.debug(
            'dvf_generation: We are out of points. Plz change the threshold value of Canny method!!!!! '
        )  # Old method. only edges!
    while (len(i_edge[0][0]) > 4) and (len(i_edge[1][0]) > 4) and (len(
            i_edge[2][0]) > 4) and (itr_edge < number_dilation):
        # i_edge will change at the end of this while loop!
        no_more_dilatation_in_this_region = False
        for dim in range(3):
            select_voxel[dim] = int(
                random_state.randint(0,
                                     len(i_edge[dim][0]) - 1,
                                     1,
                                     dtype=np.int64)[0])
            block_low[dim], block_high[dim] = center_to_block(
                setting,
                center=np.array([
                    i_edge[dim][0][select_voxel[dim]],
                    i_edge[dim][1][select_voxel[dim]],
                    i_edge[dim][2][select_voxel[dim]]
                ]),
                radius=round(exp_setting['MixedFrequency_BlockRadius'] / stage),
                im_ref=im_sitk)
        if itr_edge == 0:
            struct = np.ones((3, 3, 3), dtype=bool)
            for dim in range(3):
                dilated_edge[:, :, :, dim] = ndimage.binary_dilation(
                    dilated_edge[:, :, :, dim], structure=struct)

        elif itr_edge < np.round(10 * number_dilation / 12):
            # We like to include zero deformation in our training set.
            no_more_dilatation_in_this_region = True
            for dim in range(3):
                dilated_edge[block_low[dim][0]:block_high[dim][0],
                             block_low[dim][1]:block_high[dim][1],
                             block_low[dim][2]:block_high[dim][2], dim] = False

        elif itr_edge < np.round(11 * number_dilation / 12):
            struct = ndimage.generate_binary_structure(3, 2)
            for dim in range(3):
                mask_for_edge_dilation = np.zeros(
                    np.shape(dilated_edge[:, :, :, dim]), dtype=bool)
                mask_for_edge_dilation[
                    block_low[dim][0]:block_high[dim][0],
                    block_low[dim][1]:block_high[dim][1],
                    block_low[dim][2]:block_high[dim][2]] = True
                dilated_edge[:, :, :, dim] = ndimage.binary_dilation(
                    dilated_edge[:, :, :, dim],
                    structure=struct,
                    mask=mask_for_edge_dilation)
            if (itr_edge % 2) == 0:
                no_more_dilatation_in_this_region = True
        elif itr_edge < number_dilation:
            # One-sided structure: grows the mask in one direction along an axis
            # chosen by itr_edge % 3.
            struct = np.zeros((9, 9, 9), dtype=bool)
            if (itr_edge % 3) == 0:
                struct[0:5, :, :] = True
            if (itr_edge % 3) == 1:
                struct[:, 0:5, :] = True
            if (itr_edge % 3) == 2:
                struct[:, :, 0:5] = True
            for dim in range(3):
                mask_for_edge_dilation = np.zeros(
                    np.shape(dilated_edge[:, :, :, dim]), dtype=bool)
                mask_for_edge_dilation[
                    block_low[dim][0]:block_high[dim][0],
                    block_low[dim][1]:block_high[dim][1],
                    block_low[dim][2]:block_high[dim][2]] = True
                dilated_edge[:, :, :, dim] = ndimage.binary_dilation(
                    dilated_edge[:, :, :, dim],
                    structure=struct,
                    mask=mask_for_edge_dilation)
            if random_state.uniform() > 0.3:
                no_more_dilatation_in_this_region = True
        if no_more_dilatation_in_this_region:
            # Retire the selected block of every component. This previously used
            # the loop variable `dim` leaked from the loops above (always 2), so
            # only component 2's block was ever retired.
            for dim in range(3):
                available_region[block_low[dim][0]:block_high[dim][0],
                                 block_low[dim][1]:block_high[dim][1],
                                 block_low[dim][2]:block_high[dim][2],
                                 dim] = False
        if itr_edge >= np.round(10 * number_dilation / 12):
            for dim in range(3):
                i_edge[dim] = np.where(available_region[:, :, :, dim] > 0)
        itr_edge += 1

    bcoeff = bspline_coeff(im_sitk,
                           max_deform,
                           grid_border_to_zero,
                           grid_smoothing_sigma,
                           grid_spacing,
                           random_state,
                           dim_im=3,
                           artificial_generation='mixed_frequency')
    dvf_filter = sitk.TransformToDisplacementFieldFilter()
    dvf_filter.SetSize(im_sitk.GetSize())
    smoothed_values_sitk = dvf_filter.Execute(bcoeff)
    smoothed_values = sitk.GetArrayFromImage(smoothed_values_sitk)

    # Keep the smooth B-spline field only where the per-component masks are set.
    dvf = dilated_edge.astype(np.float64) * smoothed_values
    if setting['DVFPad_S' + str(stage)] > 0:
        pad = setting['DVFPad_S' + str(stage)]
        dvf = np.pad(dvf, ((pad, pad), (pad, pad), (pad, pad), (0, 0)),
                     'constant',
                     constant_values=(0, ))

    sigma_range = exp_setting['MixedFrequency_SigmaRange'][deform_number]
    sigma = random_state.uniform(low=sigma_range[0] / stage,
                                 high=sigma_range[1] / stage,
                                 size=3)
    dvf = smooth_dvf(dvf,
                     sigma_blur=sigma,
                     parallel_processing=setting['ParallelSearching'])

    if exp_setting['DVFNormalization']:
        dvf = normalize_dvf(dvf, max_deform)

    return dvf