Example #1
0
def cylinder_mask(setting, cn=None, overwrite=False):
    """Write an eroded binary cylinder mask for every image type of one case.

    For each index in ``setting['types']`` the image at the 'Im' address is
    thresholded: voxels whose value lies in
    ``[defaultPixelValue - 1, defaultPixelValue + 0.01]`` become 0, all other
    voxels become 1. The result is eroded twice with a (1, 3, 3) structuring
    element and written to the 'Cylinder' address.

    :param setting: setting dictionary; 'types' and 'defaultPixelValue' are read.
    :param cn: case number, forwarded to ``su.address_generator``.
    :param overwrite: if True, masks are regenerated even when the file exists.
    :return: None
    """
    cylinder_folder = su.address_generator(setting,
                                           'Cylinder',
                                           cn=cn,
                                           type_im=0).rsplit('/',
                                                             maxsplit=1)[0]
    # exist_ok=True removes the window between an isdir() check and makedirs()
    # in which a parallel worker could create the folder and make this call fail.
    os.makedirs(cylinder_folder, exist_ok=True)

    structure = np.ones((1, 3, 3))
    for type_im in range(len(setting['types'])):
        cylinder_mask_address = su.address_generator(setting,
                                                     'Cylinder',
                                                     cn=cn,
                                                     type_im=type_im)
        if os.path.isfile(cylinder_mask_address) and not overwrite:
            continue

        image_sitk = sitk.ReadImage(
            su.address_generator(setting, 'Im', cn=cn, type_im=type_im))
        # Voxels close to the default (padding) value are marked 0, the rest 1.
        cylinder_mask_sitk = sitk.BinaryThreshold(
            image_sitk,
            lowerThreshold=setting['defaultPixelValue'] - 1,
            upperThreshold=setting['defaultPixelValue'] + 0.01,
            insideValue=0,
            outsideValue=1)
        # erosion with ndimage is 5 times faster than SimpleITK
        cylinder_mask_eroded = ndimage.binary_erosion(
            sitk.GetArrayFromImage(cylinder_mask_sitk),
            structure=structure,
            iterations=2).astype(np.int8)
        cylinder_mask_eroded_sitk = ip.array_to_sitk(cylinder_mask_eroded,
                                                     im_ref=image_sitk)
        sitk.WriteImage(cylinder_mask_eroded_sitk, cylinder_mask_address)
        logging.debug('%s is done', cylinder_mask_address)
Example #2
0
 def run(self):
     """Generate the shuffled-index files of the first semi-epoch.

     While ``self._semi_epoch`` is 0, chunks whose 'IShuffled' file already
     exists on disk are skipped with
     ``go_to_next_chunk_without_going_to_fill()``; the first missing chunk is
     produced by ``fill()`` followed by ``go_to_next_chunk()``. The loop ends
     as soon as ``self._semi_epoch`` becomes non-zero.

     :return: None
     """

     def current_ishuffled_address():
         # The address depends on the current chunk / semi-epoch, so it has to
         # be rebuilt after every chunk change.
         return su.address_generator(
             self._setting,
             'IShuffled',
             train_mode=self._train_mode,
             number_of_images_per_chunk=self._number_of_images_per_chunk,
             samples_per_image=self._samples_per_image,
             semi_epoch=self._semi_epoch,
             chunk=self._chunk,
             stage=self._stage)

     while self._semi_epoch == 0:
         while (os.path.isfile(current_ishuffled_address())
                and not self._semi_epoch):
             # Arguments follow the placeholder order (stage, semiEpoch, Chunk);
             # stage and semi-epoch used to be swapped in this message.
             logging.debug(
                 'Direct1stEpoch: for stage={}, semiEpoch={}, Chunk={} is already generated, going to next chunk'
                 .format(self._stage, self._semi_epoch, self._chunk))
             self.go_to_next_chunk_without_going_to_fill()
         if not self._semi_epoch:
             self.fill()
             self.go_to_next_chunk()
     logging.debug(
         'Direct1stEpoch: exiting . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .'
     )
Example #3
0
def resampling(data, spacing=None, requested_im_list=None):
    """Resample the raw original images of a dataset and write the results.

    Every requested image kind of every case number ('CNList') and image type
    ('types') of ``data`` is read from the 'original<kind>Raw' address,
    resampled with ``ip.resampler_sitk`` and written to the 'original<kind>'
    address.

    :param data: key of the dataset inside ``setting['data']``.
    :param spacing: target spacing passed to the resampler; ``[1, 1, 1]`` when None.
    :param requested_im_list: image kinds to process; ``['Im']`` when None.
        'Im' uses B-spline interpolation, 'Mask' and 'Torso' use nearest
        neighbour.
    :raises ValueError: for any other image kind.
    :return: None
    """
    spacing = [1, 1, 1] if spacing is None else spacing
    requested_im_list = ['Im'] if requested_im_list is None else requested_im_list

    setting = su.load_setting_from_data_dict(
        su.initialize_setting(''),
        [{'data': data, 'deform_exp': '3D_max25_D12'}])

    label_like_images = ('Mask', 'Torso')
    for type_im in range(len(setting['data'][data]['types'])):
        for cn in setting['data'][data]['CNList']:
            address_kwargs = {'data': data, 'type_im': type_im, 'cn': cn}
            for requested_im in requested_im_list:
                # Label-like images must not be blurred by the interpolation.
                if requested_im in label_like_images:
                    interpolator = sitk.sitkNearestNeighbor
                elif requested_im == 'Im':
                    interpolator = sitk.sitkBSpline
                else:
                    raise ValueError(
                        'interpolator is only defined for ["Im", "Mask", "Torso"] not for '
                        + requested_im)

                raw_address = su.address_generator(
                    setting, 'original' + requested_im + 'Raw', **address_kwargs)
                resampled_sitk = ip.resampler_sitk(
                    sitk.ReadImage(raw_address),
                    spacing,
                    default_pixel_value=setting['data'][data]['defaultPixelValue'],
                    interpolator=interpolator,
                    dimension=3)
                resampled_address = su.address_generator(
                    setting, 'original' + requested_im, **address_kwargs)
                sitk.WriteImage(resampled_sitk, resampled_address)
Example #4
0
def landmarks_from_dvf_old(setting, IN_test_list):
    """Collect landmark information for several image numbers and pickle it.

    For every image number ``IN`` in ``IN_test_list`` a ``real_pair.Images``
    object is prepared and its landmark arrays are stored in ``landmarks[IN]``
    together with 'DVFRegNet', the values of the 'dvf_s0' array sampled at the
    fixed landmark indices (index columns 2, 1, 0 address array axes 0, 1, 2).
    The list of dictionaries is written to the 'landmarks_file' address.

    :param setting: setting dictionary forwarded to ``su`` and ``real_pair``.
    :param IN_test_list: iterable of integer image numbers, used as list indices.
    :raises ValueError: if the pickle file already exists (it is never overwritten).
    :return: None
    """
    # %%------------------------------------------- Setting of generating synthetic DVFs------------------------------------------
    saved_file = su.address_generator(setting, 'landmarks_file')
    if os.path.isfile(saved_file):
        raise ValueError('cannot overwrite, please change the name of the pickle file: ' + saved_file)

    # %%------------------------------------------------------  running   ---------------------------------------------------------
    IN_test_list = list(IN_test_list)
    # Keep the historical minimum of 22 slots so existing pickles keep their
    # layout, but grow the list when a larger image number is requested
    # (previously that raised IndexError).
    number_of_slots = max(22, max(IN_test_list, default=-1) + 1)
    landmarks = [{} for _ in range(number_of_slots)]
    for IN in IN_test_list:
        image_pair = real_pair.Images(IN, setting=setting)
        image_pair.prepare_for_landmarks(padding=False)
        fixed_index = image_pair._fixed_landmarks_index
        landmarks[IN]['FixedLandmarksWorld'] = image_pair._fixed_landmarks_world.copy()
        landmarks[IN]['MovingLandmarksWorld'] = image_pair._moving_landmarks_world.copy()
        landmarks[IN]['FixedAfterAffineLandmarksWorld'] = image_pair._fixed_after_affine_landmarks_world.copy()
        landmarks[IN]['DVFAffine'] = image_pair._dvf_affine.copy()
        landmarks[IN]['DVF_nonrigidGroundTruth'] = image_pair._moving_landmarks_world - image_pair._fixed_after_affine_landmarks_world
        # Copied like the other entries so the stored result does not alias
        # the internal array of image_pair.
        landmarks[IN]['FixedLandmarksIndex'] = fixed_index.copy()
        dvf_s0_sitk = sitk.ReadImage(su.address_generator(setting, 'dvf_s0', cn=IN))
        dvf_s0 = sitk.GetArrayFromImage(dvf_s0_sitk)
        landmarks[IN]['DVFRegNet'] = np.stack([dvf_s0[fixed_index[i, 2],
                                                      fixed_index[i, 1],
                                                      fixed_index[i, 0]]
                                               for i in range(len(fixed_index))])
        print('CPU: IN = {} is done in '.format(IN))
    with open(saved_file, 'wb') as f:
        pickle.dump(landmarks, f)
Example #5
0
def affine_transformix_torso(setting, pair_info, stage=1, overwrite=False):
    """Apply the affine transform of a pair to the moving torso mask.

    The 'TransformParameters.0.txt' file in the affine folder is copied to a
    'torso_moved/' sub-folder with the final B-spline interpolation order
    changed from 3 to 0 and the result pixel type changed from "short" to
    "char". transformix is then run on the moving 'originalTorso' image and
    its 'result.mha' is moved to the 'MovedTorsoAffine' address.

    :param setting: setting dictionary; 'reg_NumberOfThreads' is read.
    :param pair_info: two info dictionaries, [fixed, moving]; 'data', 'cn' and
        'type_im' of the fixed one are used in the log messages.
    :param stage: stage stored in the copies of both info dictionaries.
    :param overwrite: if False and the output already exists, nothing is done.
    :return: 0 when the existing output is kept, otherwise None.
    """
    im_info_fixed = copy.deepcopy(pair_info[0])
    im_info_moving = copy.deepcopy(pair_info[1])
    im_info_fixed['stage'] = stage
    im_info_moving['stage'] = stage

    moved_torso_affine_address = su.address_generator(setting,
                                                      'MovedTorsoAffine',
                                                      pair_info=pair_info,
                                                      **im_info_fixed)
    output_exists = os.path.isfile(moved_torso_affine_address)
    pair_description = 'Data=' + pair_info[0]['data'] + ' CN = {} TypeIm = {}'.format(
        pair_info[0]['cn'], pair_info[0]['type_im'])
    if output_exists and not overwrite:
        logging.debug('Affine Torso skipping... ' + pair_description)
        return 0
    if output_exists:
        logging.debug('Affine Torso overwriting... ' + pair_description)
    else:
        logging.debug('Affine Torso starting... ' + pair_description)

    affine_folder = su.address_generator(setting,
                                         'AffineFolder',
                                         pair_info=pair_info,
                                         **im_info_fixed)
    output_directory = affine_folder + 'torso_moved/'
    # exist_ok avoids failing when a parallel worker creates the folder first.
    os.makedirs(output_directory, exist_ok=True)

    parameter_old_address = affine_folder + 'TransformParameters.0.txt'
    parameter_new_address = output_directory + 'TransformParameters.0.txt'
    with open(parameter_old_address, "r") as text_string:
        parameter = text_string.read()
    # Interpolation order 0 and a "char" result keep the transformed image a
    # label image instead of interpolated intensities.
    parameter = parameter.replace('FinalBSplineInterpolationOrder 3',
                                  'FinalBSplineInterpolationOrder 0')
    parameter = parameter.replace('ResultImagePixelType "short"',
                                  'ResultImagePixelType "char"')
    with open(parameter_new_address, "w") as text_string:
        text_string.write(parameter)

    elxpy.transformix(parameter_file=parameter_new_address,
                      output_directory=output_directory,
                      transformix_address='transformix',
                      input_image=su.address_generator(setting,
                                                       'originalTorso',
                                                       **im_info_moving),
                      threads=setting['reg_NumberOfThreads'])
    old_moved_torso_affine_address = output_directory + 'result.mha'
    # os.replace overwrites an existing destination on every platform, whereas
    # os.rename raises on Windows when overwrite=True and the file exists.
    os.replace(old_moved_torso_affine_address, moved_torso_affine_address)
def lung_fill_hole_erode(setting, cn=None):
    """Post-process the filled lung masks of one case and save them.

    For each index in ``setting['types']`` the 'Lung_Filled' image is read,
    binarised (``> 0``), processed with a 3x3x3 structuring element and written
    as int8 to the 'Lung_Filled_Erode' address.

    :param setting: setting dictionary; 'types' is read.
    :param cn: case number, forwarded to ``su.address_generator``.
    :return: None
    """
    folder = su.address_generator(setting, 'Lung_Filled_Erode', cn=cn, type_im=0).rsplit('/', maxsplit=1)[0]
    # exist_ok avoids failing when a parallel worker creates the folder first.
    os.makedirs(folder, exist_ok=True)

    # The builtin bool replaces the np.bool alias, which NumPy 1.24 removed.
    structure = np.ones([3, 3, 3], dtype=bool)
    for type_im in range(len(setting['types'])):
        lung_raw_filled_sitk = sitk.ReadImage(su.address_generator(setting, 'Lung_Filled', cn=cn, type_im=type_im))
        lung_raw_filled = sitk.GetArrayFromImage(lung_raw_filled_sitk) > 0

        # NOTE(review): the function and output names say "erode" but the
        # operation is a binary dilation; kept as is - confirm which is intended.
        # ndimage.binary_dilation is the non-deprecated spelling of
        # ndimage.morphology.binary_dilation.
        lung_filled_erode = ndimage.binary_dilation(lung_raw_filled, structure=structure).astype(np.int8)
        sitk.WriteImage(ip.array_to_sitk(lung_filled_erode, im_ref=lung_raw_filled_sitk),
                        su.address_generator(setting, 'Lung_Filled_Erode', cn=cn, type_im=type_im))
Example #7
0
 def __init__(
     self,
     setting=None,
     number_of_images_per_chunk=None,  # number of images that I would like to load in RAM
     samples_per_image=None,
     im_info_list_full=None,
     train_mode=None,
     semi_epoch=0,
     stage=None,
 ):
     """Initialise the chunk bookkeeping and make sure the IShuffled folder exists.

     :param setting: setting dictionary; ``setting['stage']`` is used when
         ``stage`` is None.
     :param number_of_images_per_chunk: number of images per chunk. It sizes
         ``_dvf_list`` and therefore has to be an int; the default None
         raises TypeError.
     :param samples_per_image: stored as ``_samples_per_image``.
     :param im_info_list_full: stored as ``_im_info_list_full``.
     :param train_mode: stored and used to locate the IShuffled folder.
     :param semi_epoch: initial semi-epoch counter.
     :param stage: stage; falls back to ``setting['stage']`` when None.
     """
     self._setting = setting
     self._number_of_images_per_chunk = number_of_images_per_chunk
     self._samples_per_image = samples_per_image
     self._chunk = 0
     self._chunks_completed = 0
     self._semi_epochs_completed = 0
     self._semi_epoch = semi_epoch
     self._batch_counter = 0
     self._dvf_list = [None] * number_of_images_per_chunk
     self._im_info_list_full = im_info_list_full
     self._train_mode = train_mode
     if stage is None:
         stage = setting['stage']
     self._stage = stage
     self._ishuffled = None
     ishuffled_folder = su.address_generator(setting,
                                             'IShuffledFolder',
                                             train_mode=train_mode,
                                             stage=stage)
     # exist_ok=True: another object or thread may create the same folder
     # between a separate isdir() check and makedirs().
     os.makedirs(ishuffled_folder, exist_ok=True)
def initialize(current_experiment, stage_list):
    """Build the setting, pick a fresh backup folder and start logging into it.

    The optional command-line argument ``--where_to_run`` / ``-w`` is parsed
    and forwarded to ``su.initialize_setting``. The first non-existing
    'backup-<n>/' folder (n = 1, 2, ...) below the 'result_step_folder'
    address is selected, the log file is set to 'log.txt' inside it and a copy
    of this source file is stored there.

    :param current_experiment: forwarded to ``su.initialize_setting``.
    :param stage_list: forwarded to ``su.address_generator``.
    :return: tuple ``(setting, backup_folder)``.
    """
    parser = argparse.ArgumentParser(description='read where_to_run')
    parser.add_argument(
        '--where_to_run',
        '-w',
        help=
        'This is an optional argument, you choose between "Auto" or "Cluster". The default value is "Auto"'
    )
    args = parser.parse_args()
    setting = su.initialize_setting(current_experiment=current_experiment,
                                    where_to_run=args.where_to_run)

    backup_root_folder = su.address_generator(setting,
                                              'result_step_folder',
                                              stage_list=stage_list)
    backup_number = 1
    backup_folder = backup_root_folder + 'backup-' + str(backup_number) + '/'
    while os.path.isdir(backup_folder):
        backup_number += 1
        backup_folder = backup_root_folder + 'backup-' + str(backup_number) + '/'

    # Create the folder explicitly so the copy below does not depend on the
    # logger having created it as a side effect.
    os.makedirs(backup_folder, exist_ok=True)
    gut.logger.set_log_file(backup_folder + 'log.txt', short_mode=True)

    script_path = os.path.realpath(__file__)
    # basename also works when the real path contains no '/' separator, where
    # rsplit('/', maxsplit=1)[1] raised IndexError.
    shutil.copy(script_path, backup_folder + os.path.basename(script_path))
    return setting, backup_folder
Example #9
0
def affine_transformix_points(setting, pair_info, stage=1, overwrite=False):
    """
    Transform the fixed landmark points with the affine transform of a pair by calling transformix.
    The same result could also be obtained by reading the affine parameters and doing the math.

    The file at the 'reg_AffineOutputPoints' address is only used to decide whether the work
    still has to be done.
    :param setting: setting dictionary; 'reg_NumberOfThreads' is read.
    :param pair_info: two info dictionaries, [fixed, moving]; only the fixed one is used here.
    :param stage: stage stored in the copy of the fixed info dictionary.
    :param overwrite: if False and the output points file already exists, nothing is done.
    :return: 0 when the existing output is kept, otherwise None.
    """
    fixed_info = copy.deepcopy(pair_info[0])
    fixed_info['stage'] = stage

    output_points_address = su.address_generator(setting,
                                                 'reg_AffineOutputPoints',
                                                 pair_info=pair_info,
                                                 **fixed_info)
    already_done = os.path.isfile(output_points_address)
    pair_description = 'Data=' + pair_info[0]['data'] + ' CN = {} TypeIm = {}'.format(
        pair_info[0]['cn'], pair_info[0]['type_im'])

    if already_done and not overwrite:
        logging.debug('Affine transformix skipping... ' + pair_description)
        return 0
    action = 'overwriting' if already_done else 'starting'
    logging.debug('Affine transformix ' + action + '... ' + pair_description)

    landmark_points_address = su.address_generator(
        setting, 'LandmarkPoint_elx', pair_info=pair_info, **fixed_info)
    affine_folder = su.address_generator(setting,
                                         'AffineFolder',
                                         pair_info=pair_info,
                                         **fixed_info)
    # The transform parameters are read from, and the result written to, the affine folder.
    elxpy.transformix(parameter_file=affine_folder + 'TransformParameters.0.txt',
                      output_directory=affine_folder,
                      points=landmark_points_address,
                      transformix_address='transformix',
                      threads=setting['reg_NumberOfThreads'])
Example #10
0
def regressionPlot(y_plot, yHat_plot, itr, batchSizeTrain, plot_mode, setting=None):
    """Save, per direction, a plot of predicted versus reference values.

    For each of the three components on the last axis, both arrays are
    flattened, ordered by increasing reference value and plotted in one figure
    that is saved to the 'plot_fig' address. Plotting is best effort: a failure
    for one component is reported on stdout and the next one is attempted.

    :param y_plot: reference values, indexed as ``[:, :, :, :, i]`` with i in 0..2.
    :param yHat_plot: predicted values, indexed like ``y_plot``.
    :param itr: iteration number; ``itr * batchSizeTrain`` labels the curves and the file.
    :param batchSizeTrain: training batch size.
    :param plot_mode: forwarded to ``su.address_generator`` as ``plot_mode``.
    :param setting: setting dictionary forwarded to ``su.address_generator``.
    :return: None
    """
    # exist_ok avoids the race between an isdir() check and makedirs().
    os.makedirs(su.address_generator(setting, 'Plots_folder'), exist_ok=True)

    y_dir_plot = np.empty([np.size(y_plot[:, :, :, :, 0]), 3])
    yHat_dir_plot = np.empty([np.size(yHat_plot[:, :, :, :, 0]), 3])

    for i in range(3):
        fig = None
        try:
            y_dir_plot[:, i] = y_plot[:, :, :, :, i].flatten()
            yHat_dir_plot[:, i] = yHat_plot[:, :, :, :, i].flatten()
            fig = plt.figure(figsize=(22, 12))
            sort_indices = np.argsort(y_dir_plot[:, i])
            plt.plot(yHat_dir_plot[:, i][sort_indices], label='RegNet dir' + str(i) + '_itr' + str(itr * batchSizeTrain))
            plt.plot(y_dir_plot[:, i][sort_indices], label='y dir' + str(i) + '_itr' + str(itr * batchSizeTrain))
            plt.legend(bbox_to_anchor=(1., .8))
            plt.ylim((-22, 22))
            plt.draw()
            plt.savefig(su.address_generator(setting, 'plot_fig', plot_mode=plot_mode, plot_itr=itr * batchSizeTrain, plot_i=i))
        except Exception as exc:
            # Only ordinary errors are swallowed; KeyboardInterrupt and
            # SystemExit propagate. The cause is printed instead of hidden.
            print('error in plotting... ' + repr(exc))
        finally:
            # Close the figure on failure as well, so figures do not
            # accumulate in memory.
            if fig is not None:
                plt.close(fig)
Example #11
0
def affine(setting, pair_info, parameter_name, stage=1, overwrite=False):
    """Run an affine elastix registration for one image pair.

    The fixed and moving 'originalIm' images are registered with the parameter
    file ``parameter_name`` from the 'ParameterFolder' address; the output goes
    to the 'AffineFolder' address. When ``setting['reg_UseMask']`` is true, the
    mask named by ``setting['MaskName_Affine'][0]`` is passed for both images.

    :param setting: setting dictionary; 'reg_UseMask', 'MaskName_Affine' and
        'reg_NumberOfThreads' are read.
    :param pair_info: two info dictionaries, [fixed, moving].
    :param parameter_name: file name appended to the parameter folder address.
    :param stage: stage stored in the copies of both info dictionaries.
    :param overwrite: if False and the 'MovedImAffine' file exists, nothing is done.
    :return: 0 when the existing result is kept, otherwise None.
    """
    fixed_info = copy.deepcopy(pair_info[0])
    moving_info = copy.deepcopy(pair_info[1])
    fixed_info['stage'] = stage
    moving_info['stage'] = stage

    moved_address = su.address_generator(setting,
                                         'MovedImAffine',
                                         pair_info=pair_info,
                                         **fixed_info)
    result_exists = os.path.isfile(moved_address)
    pair_description = 'Data=' + pair_info[0]['data'] + ' CN = {} TypeIm = {}'.format(
        pair_info[0]['cn'], pair_info[0]['type_im'])
    if result_exists and not overwrite:
        logging.debug('Affine registration skipping... ' + pair_description)
        return 0
    if result_exists:
        logging.debug('Affine registration overwriting... ' + pair_description)
    else:
        logging.debug('Affine registration starting... ' + pair_description)

    # Masks stay None unless masking is switched on in the setting.
    fixed_mask_address = moving_mask_address = None
    if setting['reg_UseMask']:
        mask_name = setting['MaskName_Affine'][0]
        fixed_mask_address = su.address_generator(setting, mask_name, **fixed_info)
        moving_mask_address = su.address_generator(setting, mask_name, **moving_info)

    parameter_file = su.address_generator(
        setting, 'ParameterFolder', pair_info=pair_info, **fixed_info) + parameter_name
    output_directory = su.address_generator(
        setting, 'AffineFolder', pair_info=pair_info, **fixed_info)
    fixed_image = su.address_generator(setting, 'originalIm', **fixed_info)
    moving_image = su.address_generator(setting, 'originalIm', **moving_info)

    elxpy.elastix(parameter_file=parameter_file,
                  output_directory=output_directory,
                  elastix_address='elastix',
                  fixed_image=fixed_image,
                  moving_image=moving_image,
                  fixed_mask=fixed_mask_address,
                  moving_mask=moving_mask_address,
                  initial_transform=None,
                  threads=setting['reg_NumberOfThreads'])
Example #12
0
    def load_pair(self):
        """Read the fixed and the (affinely) moved image of the pair, plus torso masks.

        A pair whose moving info contains the key 'dsmooth' is treated as
        synthetic: 'Im' / 'deformedIm' (and 'Torso' / 'deformedTorso') are read.
        Otherwise 'originalIm' / 'MovedImAffine' (and 'originalTorso' /
        'MovedTorsoAffine') are read, the moved ones located with ``pair_info``.
        Torso masks are only read when ``self._setting['torsoMask']`` is true.

        :return: tuple (fixed_im_sitk, moved_im_affine_sitk, fixed_torso_sitk,
                 moved_torso_affine_sitk); the torso entries are None when
                 torso masks are disabled.
        """
        setting = self._setting
        fixed_info = copy.deepcopy(self._pair_info[0])
        fixed_info['stage'] = self._stage
        moving_info = copy.deepcopy(self._pair_info[1])
        moving_info['stage'] = self._stage

        def read_image(requested, info, with_pair_info=False):
            # Resolve the address of one image and read it with SimpleITK.
            if with_pair_info:
                address = su.address_generator(setting,
                                               requested,
                                               pair_info=self._pair_info,
                                               **info)
            else:
                address = su.address_generator(setting, requested, **info)
            return sitk.ReadImage(address)

        # in this case it means that images are synthetic
        is_synthetic = 'dsmooth' in moving_info
        if is_synthetic:
            fixed_im_name, moved_im_name = 'Im', 'deformedIm'
            fixed_torso_name, moved_torso_name = 'Torso', 'deformedTorso'
        else:
            fixed_im_name, moved_im_name = 'originalIm', 'MovedImAffine'
            fixed_torso_name, moved_torso_name = 'originalTorso', 'MovedTorsoAffine'
        needs_pair_info = not is_synthetic

        fixed_im_sitk = read_image(fixed_im_name, fixed_info)
        moved_im_affine_sitk = read_image(moved_im_name,
                                          moving_info,
                                          with_pair_info=needs_pair_info)

        fixed_torso_sitk = None
        moved_torso_affine_sitk = None
        if setting['torsoMask']:
            fixed_torso_sitk = read_image(fixed_torso_name, fixed_info)
            moved_torso_affine_sitk = read_image(moved_torso_name,
                                                 moving_info,
                                                 with_pair_info=needs_pair_info)

        return fixed_im_sitk, moved_im_affine_sitk, fixed_torso_sitk, moved_torso_affine_sitk
Example #13
0
def landmarks_from_dvf(setting, pair_info):
    """Gather the landmark data of one pair together with the DVF sampled at the landmarks.

    A stage-1 ``real_pair.Images`` object is prepared for landmarks and the
    'dvf_s0' image is read as an array. 'DVFRegNet' holds the array values at
    the fixed landmark indices, where index columns 2, 1, 0 address array
    axes 0, 1, 2.

    :param setting: setting dictionary; 'ImagePyramidSchedule' is read.
    :param pair_info: information of the pair, forwarded to ``real_pair`` and ``su``.
    :return: dictionary with the keys 'setting', 'pair_info', 'FixedLandmarksWorld',
             'MovingLandmarksWorld', 'FixedAfterAffineLandmarksWorld', 'DVFAffine',
             'DVF_nonrigidGroundTruth', 'FixedLandmarksIndex' and 'DVFRegNet'.
    """
    stage_list = setting['ImagePyramidSchedule']
    pair = real_pair.Images(setting, pair_info, stage=1)
    pair.prepare_for_landmarks(padding=False)

    dvf_address = su.address_generator(setting, 'dvf_s0', pair_info=pair_info, stage_list=stage_list)
    dvf_s0 = sitk.GetArrayFromImage(sitk.ReadImage(dvf_address))

    current_landmark = dict()
    current_landmark['setting'] = setting
    current_landmark['pair_info'] = pair_info
    current_landmark['FixedLandmarksWorld'] = pair._fixed_landmarks_world.copy()
    current_landmark['MovingLandmarksWorld'] = pair._moving_landmarks_world.copy()
    current_landmark['FixedAfterAffineLandmarksWorld'] = pair._fixed_after_affine_landmarks_world.copy()
    current_landmark['DVFAffine'] = pair._dvf_affine.copy()
    # Residual displacement that is left after the affine alignment.
    current_landmark['DVF_nonrigidGroundTruth'] = (pair._moving_landmarks_world -
                                                   pair._fixed_after_affine_landmarks_world)
    current_landmark['FixedLandmarksIndex'] = pair._fixed_landmarks_index.copy()

    dvf_at_landmarks = []
    for landmark_index in pair._fixed_landmarks_index:
        dvf_at_landmarks.append(dvf_s0[landmark_index[2], landmark_index[1], landmark_index[0]])
    current_landmark['DVFRegNet'] = np.stack(dvf_at_landmarks)
    return current_landmark
Example #14
0
def calculate_landmark(setting, pair_info, network_dict, overwrite_landmarks=False):
    """Compute the landmark dictionary of one pair and store it in the landmark pickle.

    The pickle at the 'landmarks_file' address holds a list with one dictionary
    per pair. If the list already contains an entry for ``pair_info`` it is
    either kept (``overwrite_landmarks=False``) or replaced by a freshly
    computed one, so the file never holds two entries for the same pair.

    :param setting: setting dictionary; 'ImagePyramidSchedule' is read.
    :param pair_info: two info dictionaries, [fixed, moving]; 'data', 'cn' and
        'type_im' of both are used in the log messages.
    :param network_dict: deep-copied into the new entry under 'network_dict'.
    :param overwrite_landmarks: recompute an already stored pair.
    :return: 1 when an existing entry is kept, otherwise the updated list of
             landmark dictionaries.
    """
    time_before = time.time()
    stage_list = setting['ImagePyramidSchedule']
    landmark_address = su.address_generator(setting, 'landmarks_file', stage_list=stage_list)
    if os.path.isfile(landmark_address):
        # NOTE: pickle.load can execute arbitrary code; this is only acceptable
        # because the file is expected to be one written by this function.
        with open(landmark_address, 'rb') as f:
            landmark = pickle.load(f)
    else:
        landmark = []

    def is_same_pair(entry):
        # dict equality ignores key order, so no sorting is required.
        return (pair_info[0] == entry['pair_info'][0] and
                pair_info[1] == entry['pair_info'][1])

    pair_description = (' Fixed: ' + pair_info[0]['data'] +
                        '_CN{}_TypeIm{},'.format(pair_info[0]['cn'], pair_info[0]['type_im']) +
                        '  Moving:' + pair_info[1]['data'] +
                        '_CN{}_TypeIm{}'.format(pair_info[1]['cn'], pair_info[1]['type_im']))

    if any(is_same_pair(entry) for entry in landmark):
        if not overwrite_landmarks:
            logging.debug('Landmark skipping stage_list={}'.format(stage_list) + pair_description)
            return 1
        logging.debug('Landmark overwriting stage_list={}'.format(stage_list) + pair_description)
        # Drop the stale entry; appending without removing it left duplicates
        # of the same pair in the pickle.
        landmark = [entry for entry in landmark if not is_same_pair(entry)]

    landmark_dict = landmarks_from_dvf(setting, pair_info)
    landmark_dict['network_dict'] = copy.deepcopy(network_dict)
    landmark.append(landmark_dict)

    with open(landmark_address, 'wb') as f:
        pickle.dump(landmark, f)
    time_after = time.time()
    logging.debug('Landmark stage_list={}'.format(stage_list) + pair_description +
                  ' is done in {:.2f}s '.format(time_after - time_before))
    return landmark
Example #15
0
    def __init__(
        self,
        setting=None,
        number_of_images_per_chunk=None,  # number of images that I would like to load in RAM
        samples_per_image=None,
        im_info_list_full=None,
        train_mode=None,
        semi_epoch=0,
        stage=None,
    ):
        """Set up the daemon thread state and make sure the IShuffled folder exists.

        :param setting: setting dictionary; ``setting['stage']`` is used when
            ``stage`` is None.
        :param number_of_images_per_chunk: number of images per chunk. It sizes
            the image and DVF lists and therefore has to be an int; the
            default None raises TypeError.
        :param samples_per_image: stored as ``_samples_per_image``.
        :param im_info_list_full: stored as ``_im_info_list_full``.
        :param train_mode: stored and used to locate the IShuffled folder.
        :param semi_epoch: initial semi-epoch counter.
        :param stage: stage; falls back to ``setting['stage']`` when None.
        """
        threading.Thread.__init__(self)
        self.paused = False
        self.pause_cond = threading.Condition(threading.Lock())
        self.daemon = True

        self._setting = setting
        self._number_of_images_per_chunk = number_of_images_per_chunk
        self._samples_per_image = samples_per_image
        self._chunk = 0
        self._chunks_completed = 0
        self._semi_epochs_completed = 0
        self._semi_epoch = semi_epoch
        self._batch_counter = 0
        self._fixed_im_list = [None] * number_of_images_per_chunk
        self._deformed_im_list = [None] * number_of_images_per_chunk
        self._dvf_list = [None] * number_of_images_per_chunk
        self._im_info_list_full = im_info_list_full
        self._train_mode = train_mode
        self._filled = 0
        if stage is None:
            stage = setting['stage']
        self._stage = stage
        self._ishuffled = None
        ishuffled_folder = su.address_generator(setting,
                                                'IShuffledFolder',
                                                train_mode=train_mode,
                                                stage=stage)
        # exist_ok=True: another object or thread may create the same folder
        # between a separate isdir() check and makedirs().
        os.makedirs(ishuffled_folder, exist_ok=True)
Example #16
0
    def fill(self):
        """Load the current chunk of images and its shuffled indices, then pause.

        ``self._filled`` is 0 while the chunk is being loaded and is set to 1
        at the end, right before ``self.pause()`` is called.

        Steps, in order:
        - in 'Training' mode the image / DVF lists are reset;
        - if the previous call finished a semi-epoch, the semi-epoch counter is
          advanced and the chunk counter restarts at 0;
        - the image order is a permutation seeded with the semi-epoch number
          (or the identity order when ``setting['Randomness']`` is false);
        - the images of the chunk are loaded with
          ``synth.get_dvf_and_deformed_images(..., mode_synthetic_dvf='reading')``;
        - the shuffled indices are loaded from the 'IShuffled' file. In
          semi-epoch 0 the method polls every 5 s until the file exists; in
          later semi-epochs a missing file is computed and saved.

        :return: None
        """
        self._filled = 0
        number_of_images_per_chunk = self._number_of_images_per_chunk
        if self._train_mode == 'Training':
            # Make all lists empty in training mode. In validation or test mode the same chunk is kept forever,
            # so there is no need to empty and refill it.
            self._fixed_im_list = [None] * number_of_images_per_chunk
            self._deformed_im_list = [None] * number_of_images_per_chunk
            self._dvf_list = [None] * number_of_images_per_chunk

        # The previous fill() reached the end of the image list: start the next semi-epoch at chunk 0.
        if self._semi_epochs_completed:
            self._semi_epoch = self._semi_epoch + 1
            self._semi_epochs_completed = 0
            self._chunk = 0

        im_info_list_full = copy.deepcopy(self._im_info_list_full)
        # NOTE: this seeds NumPy's global random state, so the permutation is the same for a given semi-epoch.
        np.random.seed(self._semi_epoch)
        if self._setting['Randomness']:
            random_indices = np.random.permutation(len(im_info_list_full))
        else:
            random_indices = np.arange(len(im_info_list_full))

        lower_range = self._chunk * number_of_images_per_chunk
        upper_range = (self._chunk + 1) * number_of_images_per_chunk
        if upper_range >= len(im_info_list_full):
            # Last chunk of the semi-epoch: clip the range and flag the semi-epoch as completed.
            upper_range = len(im_info_list_full)
            self._semi_epochs_completed = 1
            number_of_images_per_chunk = upper_range - lower_range  # the last chunk can be smaller than self._number_of_images_per_chunk
            self._fixed_im_list = [None] * number_of_images_per_chunk
            self._deformed_im_list = [None] * number_of_images_per_chunk
            self._dvf_list = [None] * number_of_images_per_chunk

        torso_list = [None] * len(self._dvf_list)
        indices_chunk = random_indices[lower_range:upper_range]
        im_info_list = [im_info_list_full[i] for i in indices_chunk]
        for i_index_im, index_im in enumerate(indices_chunk):
            # One call returns the fixed image, deformed image, DVF and torso of a single image info.
            self._fixed_im_list[i_index_im], self._deformed_im_list[i_index_im], self._dvf_list[i_index_im], torso_list[i_index_im] = \
                synth.get_dvf_and_deformed_images(self._setting,
                                                  im_info=im_info_list_full[index_im],
                                                  stage=self._stage,
                                                  mode_synthetic_dvf='reading'
                                                  )
            if self._setting['verbose']:
                logging.debug(
                    'thread: Data=' + im_info_list_full[index_im]['data'] +
                    ', TypeIm={}, CN={}, Dsmooth={}, stage={} is loaded'.
                    format(im_info_list_full[index_im]['type_im'],
                           im_info_list_full[index_im]['cn'],
                           im_info_list_full[index_im]['dsmooth'], self._stage)
                )
        # The address uses the configured chunk size (self._number_of_images_per_chunk), not the possibly
        # smaller size of the last chunk.
        ishuffled_address = su.address_generator(
            self._setting,
            'IShuffled',
            train_mode=self._train_mode,
            number_of_images_per_chunk=self._number_of_images_per_chunk,
            samples_per_image=self._samples_per_image,
            semi_epoch=self._semi_epoch,
            chunk=self._chunk,
            stage=self._stage)
        if self._semi_epoch == 0:
            # in semiEpoch = 0 we wait for the direct_1st_epoch to create the ishuffled!
            countWait = 1
            while not os.path.isfile(ishuffled_address):
                time.sleep(5)
                logging.debug('thread: waiting {} s for IShuffled:'.format(
                    countWait * 5) + ishuffled_address)
                countWait += 1
            self._ishuffled = np.load(ishuffled_address)
        else:
            # Later semi-epochs: reuse the file when it exists, otherwise compute the indices here and save them.
            if os.path.isfile(ishuffled_address):
                self._ishuffled = np.load(ishuffled_address)
            else:
                self._ishuffled = reading_utils.shuffled_indices_from_chunk(
                    self._setting,
                    dvf_list=self._dvf_list,
                    torso_list=torso_list,
                    im_info_list=im_info_list,
                    stage=self._stage,
                    semi_epoch=self._semi_epoch,
                    chunk=self._chunk,
                    samples_per_image=self._samples_per_image,
                    number_of_images_per_chunk=number_of_images_per_chunk,
                    log_header='direct')
                np.save(ishuffled_address, self._ishuffled)

        self._filled = 1
        logging.debug('Thread is filled .....................')
        self.pause()
Example #17
0
def multi_stage(setting, network_dict, pair_info, overwrite=False):
    """
    Register one image pair with a multi-stage (image pyramid) network and write the results to disk.

    The stages listed in setting['ImagePyramidSchedule'] are processed in order. For every stage the fixed
    and moving images are down-sampled (unless the stage is 1), the moving image is warped with the
    up-sampled composition of all previous DVFs, and the network of that stage predicts a residual DVF
    patch by patch. After the last stage the composed DVF and the final moved image are written.

    :param setting: experiment setting dictionary. The keys 'R', 'Ry' and 'Ry_erode' are overwritten for every stage.
    :param network_dict: per-stage network description, indexed by 'Stage<stage>', with the keys
                         'R', 'Ry', 'Ry_erode', 'NetworkDesign', 'NetworkLoad' and 'GlobalStepLoad'.
    :param pair_info: information of the pair to be registered: pair_info[0] is the fixed image, pair_info[1] the moving image.
    :param overwrite: if False and the final moved image already exists, nothing is computed.
    :return: The output moved images and dvf will be written to the disk.
             1: registration is performed correctly
             2: skip overwriting
    """
    stage_list = setting['ImagePyramidSchedule']
    final_moved_image_address = su.address_generator(setting, 'moved_image', pair_info=pair_info, stage=0, stage_list=stage_list)
    if os.path.isfile(final_moved_image_address):
        if not overwrite:
            print('overwrite=False, file '+final_moved_image_address+' already exists, skipping .....')
            return 2
        else:
            print('overwrite=True, file '+final_moved_image_address+' already exists, but overwriting .....')

    pair_stage1 = real_pair.Images(setting, pair_info, stage=1)
    pyr = dict()  # pyr: a dictionary of pyramid images
    pyr['fixed_im_s1_sitk'] = pair_stage1.get_fixed_im_sitk()
    pyr['moving_im_s1_sitk'] = pair_stage1.get_moved_im_affine_sitk()
    pyr['fixed_im_s1'] = pair_stage1.get_fixed_im()
    pyr['moving_im_s1'] = pair_stage1.get_moved_im_affine()
    if setting['torsoMask']:
        pyr['fixed_torso_s1_sitk'] = pair_stage1.get_fixed_torso_sitk()
        pyr['moving_torso_s1_sitk'] = pair_stage1.get_moved_torso_affine_sitk()
    # exist_ok avoids the check-then-create race of a separate isdir() test
    os.makedirs(su.address_generator(setting, 'full_reg_folder', pair_info=pair_info, stage_list=stage_list), exist_ok=True)

    time_before_dvf = time.time()
    for i_stage, stage in enumerate(stage_list):
        # keys of this stage inside the pyramid dictionary
        fixed_im_key = 'fixed_im_s'+str(stage)+'_sitk'
        moving_im_key = 'moving_im_s'+str(stage)+'_sitk'
        fixed_torso_key = 'fixed_torso_s'+str(stage)+'_sitk'
        moving_torso_key = 'moving_torso_s'+str(stage)+'_sitk'
        moved_im_key = 'moved_im_s'+str(stage)+'_sitk'
        moved_torso_key = 'moved_torso_s'+str(stage)+'_sitk'
        dvf_key = 'DVF_s'+str(stage)
        moving_data_setting = setting['data'][pair_info[1]['data']]

        if stage != 1:
            pyr[fixed_im_key] = ip.downsampler_gpu(pyr['fixed_im_s1_sitk'], stage,
                                                   default_pixel_value=setting['data'][pair_info[0]['data']]['defaultPixelValue'])
            pyr[moving_im_key] = ip.downsampler_gpu(pyr['moving_im_s1_sitk'], stage,
                                                    default_pixel_value=moving_data_setting['defaultPixelValue'])
        if setting['torsoMask']:
            pyr[fixed_torso_key] = ip.downsampler_sitk(pyr['fixed_torso_s1_sitk'],
                                                       stage,
                                                       im_ref=pyr[fixed_im_key],
                                                       default_pixel_value=0,
                                                       interpolator=sitk.sitkNearestNeighbor)
            pyr[moving_torso_key] = ip.downsampler_sitk(pyr['moving_torso_s1_sitk'],
                                                        stage,
                                                        im_ref=pyr[moving_im_key],
                                                        default_pixel_value=0,
                                                        interpolator=sitk.sitkNearestNeighbor)
        else:
            pyr[fixed_torso_key] = None
            pyr[moving_torso_key] = None

        input_regnet_moving_torso = None
        if i_stage == 0:
            # first stage: the network sees the (affinely aligned) moving image directly
            input_regnet_moving = moving_im_key
            if setting['torsoMask']:
                input_regnet_moving_torso = moving_torso_key
        else:
            previous_pyramid = stage_list[i_stage - 1]
            dvf_composed_previous_up_sitk = 'DVF_s'+str(previous_pyramid)+'_composed_up_sitk'
            dvf_composed_previous_sitk = 'DVF_s'+str(previous_pyramid)+'_composed_sitk'
            if i_stage == 1:
                pyr[dvf_composed_previous_sitk] = pyr['DVF_s'+str(previous_pyramid)+'_sitk']
            else:
                # composition so far (already up-sampled to the previous stage) + residual DVF of the previous stage
                pyr[dvf_composed_previous_sitk] = sitk.Add(pyr['DVF_s'+str(stage_list[i_stage - 2])+'_composed_up_sitk'],
                                                           pyr['DVF_s'+str(previous_pyramid)+'_sitk'])
            pyr[dvf_composed_previous_up_sitk] = ip.upsampler_gpu(pyr[dvf_composed_previous_sitk],
                                                                  round(previous_pyramid/stage),
                                                                  dvf_output_size=pyr[fixed_im_key].GetSize()[::-1],
                                                                  )
            if setting['WriteAfterEachStage']:
                sitk.WriteImage(sitk.Cast(pyr[dvf_composed_previous_up_sitk], sitk.sitkVectorFloat32),
                                su.address_generator(setting, 'dvf_s_up', pair_info=pair_info, stage=previous_pyramid, stage_list=stage_list))

            dvf_t = sitk.DisplacementFieldTransform(pyr[dvf_composed_previous_up_sitk])
            # after this line DVF_composed_previous_up_sitk is converted to a transform. so we need to load it again.
            pyr[moved_im_key] = ip.resampler_by_dvf(pyr[moving_im_key],
                                                    dvf_t,
                                                    default_pixel_value=moving_data_setting['defaultPixelValue'])
            if setting['torsoMask']:
                pyr[moved_torso_key] = ip.resampler_by_dvf(pyr[moving_torso_key],
                                                           dvf_t,
                                                           default_pixel_value=0,
                                                           interpolator=sitk.sitkNearestNeighbor)
            pyr[dvf_composed_previous_up_sitk] = dvf_t.GetDisplacementField()
            if setting['WriteAfterEachStage']:
                sitk.WriteImage(sitk.Cast(pyr[moved_im_key], moving_data_setting['imageByte']),
                                su.address_generator(setting, 'moved_image', pair_info=pair_info, stage=stage, stage_list=stage_list))
            input_regnet_moving = moved_im_key
            if setting['torsoMask']:
                input_regnet_moving_torso = moved_torso_key

        pyr[dvf_key] = np.zeros(np.r_[pyr[fixed_im_key].GetSize()[::-1], 3], dtype=np.float64)
        # Without a torso mask there is no key to look up: indexing pyr with None used to raise KeyError.
        moved_torso_sitk = pyr[input_regnet_moving_torso] if input_regnet_moving_torso is not None else None
        pair_pyramid = real_pair.Images(setting, pair_info, stage=stage,
                                        fixed_im_sitk=pyr[fixed_im_key],
                                        moved_im_affine_sitk=pyr[input_regnet_moving],
                                        fixed_torso_sitk=pyr[fixed_torso_key],
                                        moved_torso_affine_sitk=moved_torso_sitk
                                        )

        # building and loading network
        tf.reset_default_graph()
        stage_network = network_dict['Stage'+str(stage)]
        setting['R'] = stage_network['R']    # Radius of normal resolution patch size. Total size is (2*R +1)
        setting['Ry'] = stage_network['Ry']  # Radius of output. Total size is (2*Ry +1)
        setting['Ry_erode'] = stage_network['Ry_erode']  # at the test time, sometimes there are some problems at the border
        patch_size = 2 * setting['R'] + 1
        images_tf = tf.placeholder(tf.float32,
                                   shape=[None, patch_size, patch_size, patch_size, 2],
                                   name="Images")
        bn_training = tf.placeholder(tf.bool, name='bn_training')
        x_fixed = images_tf[:, :, :, :, 0, np.newaxis]
        x_deformed = images_tf[:, :, :, :, 1, np.newaxis]
        dvf_tf = getattr(RegNetModel, stage_network['NetworkDesign'])(x_fixed, x_deformed, bn_training)
        logging.debug(' Total number of variables %s' % (np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        # One session per stage: close it when the sweep is finished so its resources are
        # released before the next stage builds a new graph and session.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, su.address_generator(setting, 'saved_model_with_step',
                                                     current_experiment=stage_network['NetworkLoad'],
                                                     step=stage_network['GlobalStepLoad']))
            while not pair_pyramid.get_sweep_completed():
                # The pyr[DVF_S] is the numpy DVF which will be filled. the dvf_np is an output
                # path from the network. We control the spatial location of both dvf in this function
                batch_im, win_center, win_r_before, win_r_after, predicted_begin, predicted_end = pair_pyramid.next_sweep_patch()
                time_before_gpu = time.time()
                [dvf_np] = sess.run([dvf_tf], feed_dict={images_tf: batch_im, bn_training: 0})
                time_after_gpu = time.time()
                logging.debug('GPU: Data='+pair_info[0]['data']+' CN = {} center = {} is done in {:.2f}s '.format(
                    pair_info[0]['cn'], win_center, time_after_gpu - time_before_gpu))

                pyr[dvf_key][win_center[0] - win_r_before[0]: win_center[0] + win_r_after[0],
                             win_center[1] - win_r_before[1]: win_center[1] + win_r_after[1],
                             win_center[2] - win_r_before[2]: win_center[2] + win_r_after[2], :] = \
                    dvf_np[0, predicted_begin[0]:predicted_end[0], predicted_begin[1]:predicted_end[1], predicted_begin[2]:predicted_end[2], :]

        pyr[dvf_key+'_sitk'] = ip.array_to_sitk(pyr[dvf_key],
                                                im_ref=pyr[fixed_im_key],
                                                is_vector=True)

        if i_stage == (len(stage_list)-1):
            # when all stages are finished, final dvf and moved image are written
            dvf_composed_final_sitk = 'DVF_s'+str(stage)+'_composed_sitk'
            if len(stage_list) == 1:
                pyr[dvf_composed_final_sitk] = pyr[dvf_key+'_sitk']
            else:
                pyr[dvf_composed_final_sitk] = sitk.Add(pyr['DVF_s'+str(stage_list[-2])+'_composed_up_sitk'],
                                                        pyr[dvf_key+'_sitk'])
            sitk.WriteImage(sitk.Cast(pyr[dvf_composed_final_sitk], sitk.sitkVectorFloat32),
                            su.address_generator(setting, 'dvf_s0', pair_info=pair_info, stage_list=stage_list))
            dvf_t = sitk.DisplacementFieldTransform(pyr[dvf_composed_final_sitk])
            pyr['moved_im_s0_sitk'] = ip.resampler_by_dvf(pyr[moving_im_key], dvf_t,
                                                          default_pixel_value=moving_data_setting['defaultPixelValue'])
            sitk.WriteImage(sitk.Cast(pyr['moved_im_s0_sitk'],
                                      moving_data_setting['imageByte']),
                            su.address_generator(setting, 'moved_image', pair_info=pair_info, stage=0, stage_list=stage_list))
    time_after_dvf = time.time()
    logging.debug('Data='+pair_info[0]['data']+' CN = {} ImType = {} is done in {:.2f}s '.format(
        pair_info[0]['cn'], pair_info[0]['type_im'],time_after_dvf - time_before_dvf))

    return 1
Example #18
0
        self.slices, rows, cols = X.shape
        self.ind = self.slices//2

        self.im = ax.imshow(self.X[self.ind, :, :], cmap=cmap, aspect=aspect)
        self.update()

    def onscroll(self, event):
        """Handle a scroll event: go one slice forward on 'up', one slice back otherwise (wrapping around)."""
        print(event.button, event.step)
        step = 1 if event.button == 'up' else -1
        self.ind = (self.ind + step) % self.slices
        self.update()

    def update(self):
        """Show the slice selected by self.ind, label the axis with its number and redraw the canvas."""
        current = self.ind
        self.im.set_data(self.X[current, :, :])
        self.ax.set_ylabel('slice %s' % current)
        canvas = self.im.axes.figure.canvas
        canvas.draw()


if __name__ == '__main__':
    # Manual check of the viewer: read one image addressed through the project settings
    # (deformName '3D_max15_D9', type_im=0, cn=1) and browse it slice by slice.
    current_experiment = ''
    setting = su.initialize_setting(current_experiment)
    setting['deformName'] = '3D_max15_D9'
    setting = su.load_setting_from_data_dict(setting)
    im_sitk = sitk.ReadImage(su.address_generator(setting, 'Im', type_im=0, cn=1))
    im = sitk.GetArrayFromImage(im_sitk)
    view3d_image(im, slice_axis=1)
Example #19
0
    def fill(self):
        """
        Load the data of the current chunk and obtain its shuffled sample indices.

        The image order is drawn with a seed equal to the current semi-epoch (or kept sequential when
        setting['Randomness'] is off), the DVF and torso mask of every image in the chunk are fetched via
        synth.get_dvf_and_deformed_images, and self._ishuffled is either loaded from its 'IShuffled'
        address or computed by reading_utils.shuffled_indices_from_chunk and saved there.
        """
        chunk_size = self._number_of_images_per_chunk
        if self._train_mode == 'Training':
            self._dvf_list = [None] * chunk_size
        # NOTE(review): the original author marked this branch with "This never runs".
        if self._semi_epochs_completed:
            self._semi_epoch += 1
            self._semi_epochs_completed = 0
            self._chunk = 0

        all_im_info = copy.deepcopy(self._im_info_list_full)
        total_images = len(all_im_info)
        np.random.seed(self._semi_epoch)
        image_order = (np.random.permutation(total_images)
                       if self._setting['Randomness'] else np.arange(total_images))

        first = self._chunk * chunk_size
        last = (self._chunk + 1) * chunk_size
        if last >= total_images:
            # the last chunk may hold fewer images than self._number_of_images_per_chunk
            last = total_images
            self._semi_epochs_completed = 1
            chunk_size = last - first
            self._dvf_list = [None] * chunk_size

        torso_list = [None] * len(self._dvf_list)
        chunk_order = image_order[first:last]
        im_info_list = [all_im_info[i] for i in chunk_order]
        for slot, im_info in enumerate(im_info_list):
            _, _, dvf, torso = synth.get_dvf_and_deformed_images(
                self._setting,
                im_info=im_info,
                stage=self._stage,
                mode_synthetic_dvf='generation')
            self._dvf_list[slot] = dvf
            torso_list[slot] = torso
            if self._setting['verbose']:
                loaded_message = ', TypeIm={}, CN={}, Dsmooth={}, stage={} is loaded'.format(
                    im_info['type_im'], im_info['cn'], im_info['dsmooth'], self._stage)
                logging.debug('Direct1stEpoch: Data=' + im_info['data'] + loaded_message)

        # the address is built from the nominal chunk size, not the (possibly smaller) size of the last chunk
        ishuffled_address = su.address_generator(
            self._setting,
            'IShuffled',
            train_mode=self._train_mode,
            number_of_images_per_chunk=self._number_of_images_per_chunk,
            samples_per_image=self._samples_per_image,
            semi_epoch=self._semi_epoch,
            chunk=self._chunk,
            stage=self._stage)
        if not os.path.isfile(ishuffled_address):
            self._ishuffled = reading_utils.shuffled_indices_from_chunk(
                self._setting,
                dvf_list=self._dvf_list,
                torso_list=torso_list,
                im_info_list=im_info_list,
                stage=self._stage,
                semi_epoch=self._semi_epoch,
                chunk=self._chunk,
                samples_per_image=self._samples_per_image,
                number_of_images_per_chunk=chunk_size,
                log_header='direct')
            np.save(ishuffled_address, self._ishuffled)
        else:
            self._ishuffled = np.load(ishuffled_address)
Example #20
0
def _iclass_address(setting, im_info, c, stage):
    """Return the address of the cached index array of class `c` for the image described by `im_info`."""
    return su.address_generator(
        setting,
        'IClass',
        data=im_info['data'],
        deform_exp=im_info['deform_exp'],
        cn=im_info['cn'],
        type_im=im_info['type_im'],
        dsmooth=im_info['dsmooth'],
        c=c,
        stage=stage)


def _accumulate_class_indices(indices, c, i_dvf, class_indices):
    """Append `class_indices`, tagged with the image number `i_dvf`, to indices['class<c>'] in place."""
    key = 'class' + str(c)
    tagged = np.array(
        np.c_[class_indices, i_dvf * np.ones(len(class_indices), dtype=np.int32)])
    if (i_dvf == 0) or (len(indices[key]) == 0):
        indices[key] = tagged
    else:
        indices[key] = np.concatenate((indices[key], tagged), axis=0)


def shuffled_indices_from_chunk(setting,
                                dvf_list=None,
                                torso_list=None,
                                im_info_list=None,
                                stage=None,
                                semi_epoch=None,
                                chunk=None,
                                samples_per_image=None,
                                number_of_images_per_chunk=None,
                                log_header=''):
    """
    Select a class-balanced, shuffled set of sample positions from the images of one chunk.

    For every image and every class in setting['classBalanced'] the candidate indices are loaded from
    their 'IClass' address when that file exists, otherwise they are computed and saved there. In the
    serial branch these are flat voxel indices of positions inside the margin (and inside the torso
    mask, when one is given) whose displacement falls in the range of the class; the parallel branch
    delegates to search_indices (presumably with the same meaning -- confirm against that function).

    :param setting: experiment setting; reads 'DataExpDict', 'Margin', 'classBalanced',
                    'ParallelProcessing', 'Dim' and 'verbose'.
    :param dvf_list: one DVF array per image of the chunk (last axis holds the components).
    :param torso_list: one mask (or None) per image of the chunk.
    :param im_info_list: one info dictionary per image, used to build the 'IClass' addresses.
    :param stage: pyramid stage, part of the cache address and of the random seed.
    :param semi_epoch: semi-epoch number, part of the random seed.
    :param chunk: chunk number, part of the random seed.
    :param samples_per_image: requested number of samples per image.
    :param number_of_images_per_chunk: number of images in this chunk.
    :param log_header: prefix for all log messages.
    :return: shuffled array with one row per sample: [index, image number in the chunk, class].
             An empty array if nothing could be selected.
    """
    for single_dict in setting['DataExpDict']:
        iclass_folder = su.address_generator(
            setting,
            'IClassFolder',
            data=single_dict['data'],
            deform_exp=single_dict['deform_exp'],
            stage=stage)
        # exist_ok avoids the check-then-create race of a separate isdir() test
        os.makedirs(iclass_folder, exist_ok=True)

    margin = setting['Margin']
    class_balanced = setting['classBalanced']
    number_of_classes = len(class_balanced)
    indices = {}
    start_time = time.time()
    if setting['ParallelProcessing']:
        # Leave two cores free but keep at least one worker: on a two-core machine
        # cpu_count() - 2 is 0, which is not a valid number of jobs.
        num_cores = max(1, multiprocessing.cpu_count() - 2)
        # results is laid out image-major: entry i_dvf * number_of_classes + c
        results = [None] * len(dvf_list) * number_of_classes
        count_iclass_loaded = 0
        for i_dvf, im_info in enumerate(im_info_list):
            for c in range(number_of_classes):
                iclass_address = _iclass_address(setting, im_info, c, stage)
                if os.path.isfile(iclass_address):
                    results[i_dvf * number_of_classes + c] = np.load(iclass_address)
                    count_iclass_loaded += 1
        if count_iclass_loaded != len(results):
            # at least one cache file is missing: everything is recomputed and rewritten
            logging.debug(
                log_header +
                ': not all I1 found. start calculating... semiEpoch = {}, Chunk = {}, stage={}'
                .format(semi_epoch, chunk, stage))
            results = Parallel(n_jobs=num_cores)(
                delayed(search_indices)(dvf=dvf_list[i],
                                        torso=torso_list[i],
                                        c=c,
                                        class_balanced=class_balanced,
                                        margin=margin,
                                        dim_im=setting['Dim'])
                for i in range(0, len(dvf_list))
                for c in range(0, number_of_classes))
            for i_dvf, im_info in enumerate(im_info_list):
                for c in range(number_of_classes):
                    np.save(_iclass_address(setting, im_info, c, stage),
                            results[i_dvf * number_of_classes + c])
        for iresults, class_indices in enumerate(results):
            # undo the image-major layout of the two nested loops above
            i_dvf, c = divmod(iresults, number_of_classes)
            _accumulate_class_indices(indices, c, i_dvf, class_indices)
        del results
        end_time = time.time()
        if setting['verbose']:
            logging.debug(
                log_header +
                ' Parallel searching for {} classes is Done in {:.2f}s'.format(
                    number_of_classes, end_time - start_time))
    else:
        for i_dvf, im_info in enumerate(im_info_list):
            dvf = dvf_list[i_dvf]
            dvf_shape = np.shape(dvf)[:-1]
            # builtin bool: the np.bool alias was removed from NumPy
            mask = np.zeros(dvf_shape, dtype=bool)
            mask[margin:-margin, margin:-margin, margin:-margin] = True
            if torso_list[i_dvf] is not None:
                mask = mask & torso_list[i_dvf]
            for c in range(number_of_classes):
                iclass_address = _iclass_address(setting, im_info, c, stage)
                if os.path.isfile(iclass_address):
                    i1 = np.load(iclass_address)
                else:
                    if c == 0:
                        # you can add a mask here to prevent selecting pixels twice!
                        # the output of np.where occupies a huge part of memory; converting it to a
                        # flat int32 array saves a lot of it
                        i1 = np.ravel_multi_index(
                            np.where(np.all(np.abs(dvf) < class_balanced[c], axis=3) & mask),
                            dvf_shape).astype(np.int32)
                    if (c > 0) & (c < number_of_classes):
                        if setting['Dim'] == '2D':
                            # in 2D experiments, the DVFList is still in 3D and for the third direction
                            # is set to 0. Here we use np.any() instead of np.all()
                            i1 = np.ravel_multi_index(
                                np.where(np.all(np.abs(dvf) < class_balanced[c], axis=3)
                                         & np.any(np.abs(dvf) >= class_balanced[c - 1], axis=3)
                                         & mask),
                                dvf_shape).astype(np.int32)
                        if setting['Dim'] == '3D':
                            i1 = np.ravel_multi_index(
                                np.where(np.all(np.abs(dvf) < class_balanced[c], axis=3)
                                         & np.all(np.abs(dvf) >= class_balanced[c - 1], axis=3)
                                         & mask),
                                dvf_shape).astype(np.int32)
                    np.save(iclass_address, i1)
                _accumulate_class_indices(indices, c, i_dvf, i1)
                if setting['verbose']:
                    logging.debug(log_header +
                                  ': Finding classes done for i = {}, c = {} '.
                                  format(i_dvf, c))
        # (no `del i1` here: the name is unbound when there are no images or classes)
        end_time = time.time()
        if setting['verbose']:
            logging.debug(
                log_header +
                ': Searching for {} classes is Done in {:.2f}s'.format(
                    number_of_classes, end_time - start_time))

    samples_per_chunk = samples_per_image * number_of_images_per_chunk
    sample_per_chunk_per_class = np.round(samples_per_chunk / number_of_classes)
    # zeros (not empty) so classes that are never filled do not report garbage in the log below
    number_samples_class = np.zeros(number_of_classes, dtype=np.int32)
    np.random.seed(semi_epoch * 10000 + chunk * 100 + stage)
    selected_indices = np.array([])
    for c in range(number_of_classes):
        class_indices = indices.get('class' + str(c))
        if class_indices is None:
            # no image was processed at all, so this class was never populated
            continue
        number_in_class = np.shape(class_indices)[0]
        # it is possible to have different number in each class. However we prefer to have at least
        # sample_per_chunk_per_class
        number_samples_class[c] = min(sample_per_chunk_per_class, number_in_class)
        if number_in_class > 0:
            i1 = np.random.randint(0,
                                   high=number_in_class,
                                   size=number_samples_class[c])
            selected_class = np.concatenate(
                (class_indices[i1, :],
                 c * np.ones([len(i1), 1], dtype=np.int32)),
                axis=1)
            if c == 0 or len(selected_indices) == 0:
                selected_indices = selected_class.astype(np.int32)
            else:
                selected_indices = np.concatenate(
                    (selected_indices, selected_class), axis=0)
        else:
            logging.info(
                log_header +
                ': no samples in class {} for semiEpoch = {}, Chunk = {} '.
                format(c, semi_epoch, chunk))
    if setting['verbose']:
        logging.debug(log_header +
                      ': samplesPerChunk is {} for semiEpoch = {}, Chunk = {} '
                      .format(sum(number_samples_class), semi_epoch, chunk))
    shuffled_index = np.arange(0, len(selected_indices))
    np.random.shuffle(shuffled_index)
    return selected_indices[shuffled_index]
Example #21
0
    def prepare_for_landmarks(self, padding=True):
        im_info_fixed = copy.deepcopy(self._pair_info[0])
        im_info_moving = copy.deepcopy(self._pair_info[1])
        fixed_landmarks_world = np.loadtxt(
            su.address_generator(self._setting,
                                 'LandmarkPoint_tr',
                                 pair_info=self._pair_info,
                                 **im_info_fixed))
        moving_landmarks_world = np.loadtxt(
            su.address_generator(self._setting,
                                 'LandmarkPoint_tr',
                                 pair_info=self._pair_info,
                                 **im_info_moving))
        if self._setting['data'][self._pair_info[0]
                                 ['data']]['UnsureLandmarkAvailable']:
            fixed_landmarks_unsure_list = np.loadtxt(
                su.address_generator(self._setting, 'UnsurePoints',
                                     **im_info_fixed))
            index_sure = []
            for i in range(len(fixed_landmarks_unsure_list)):
                if fixed_landmarks_unsure_list[i] == 0:
                    index_sure.append(i)
        else:
            index_sure = [i for i in range(len(fixed_landmarks_world))]
        self._fixed_landmarks_world = fixed_landmarks_world[
            index_sure]  # xyz order
        self._moving_landmarks_world = moving_landmarks_world[
            index_sure]  # xyz order
        self._fixed_landmarks_index = (np.round(
            (self._fixed_landmarks_world - self._fixed_im_sitk.GetOrigin()) /
            np.array(self._fixed_im_sitk.GetSpacing()))).astype(
                np.int16)  # xyz order

        if self._setting['data'][self._pair_info[1]
                                 ['data']]['AffineRegistration']:
            elx_all_points = elxpy.elxReadOutputPointsFile(
                su.address_generator(self._setting,
                                     'reg_AffineOutputPoints',
                                     pair_info=self._pair_info,
                                     **im_info_fixed))
            self._fixed_after_affine_landmarks_world = elx_all_points.OutputPoint[
                index_sure]
            self._dvf_affine = elx_all_points.Deformation[index_sure]
        else:
            self._fixed_after_affine_landmarks_world = self._fixed_landmarks_world.copy(
            )
            self._dvf_affine = [[0, 0, 0] for _ in range(len(index_sure))]

        # np.array(self._fixed_im_sitk.GetSpacing())          #xyz order
        # self._fixed_im_sitk.GetOrigin()                     #xyz order

        # The following lines are only used to check the dimension order.
        check_dimension_order = False
        if check_dimension_order:
            dilated_landmarks = np.zeros(np.shape(self._fixed_im),
                                         dtype=np.int8)
            rd = 7  # radius of dilation
            for i in range(np.shape(self._fixed_landmarks_index)[0]):
                dilated_landmarks[self._fixed_landmarks_index[i, 2] -
                                  rd:self._fixed_landmarks_index[i, 2] + rd,
                                  self._fixed_landmarks_index[i, 1] -
                                  rd:self._fixed_landmarks_index[i, 1] + rd,
                                  self._fixed_landmarks_index[i, 0] -
                                  rd:self._fixed_landmarks_index[i, 0] +
                                  rd] = 1
            dilated_landmarks_sitk = sitk.GetImageFromArray(dilated_landmarks)
            dilated_landmarks_sitk.SetOrigin(self._fixed_im_sitk.GetOrigin())
            dilated_landmarks_sitk.SetSpacing(
                np.array(self._fixed_im_sitk.GetSpacing()))
            sitk.WriteImage(
                sitk.Cast(dilated_landmarks_sitk, sitk.sitkInt8),
                su.address_generator(self._setting, 'DilatedLandmarksIm',
                                     **im_info_fixed))

        if padding:
            min_coordinate_landmark = np.min(self._fixed_landmarks_index,
                                             axis=0)
            max_coordinate_landmark = np.max(self._fixed_landmarks_index,
                                             axis=0)

            pad_before = np.zeros(3, dtype=np.int16)  # xyz order
            pad_after = np.zeros(3, dtype=np.int16)  # xyz order
            for i in range(0, 3):
                # be careful about the xyz or zyx order!
                if min_coordinate_landmark[i] < self._setting['R']:
                    pad_before[
                        i] = self._setting['R'] - min_coordinate_landmark[i]
                if ((np.shape(self._fixed_im)[2 - i] -
                     max_coordinate_landmark[i]) - 1) < self._setting['R']:
                    pad_after[i] = self._setting['R'] - (
                        np.shape(self._fixed_im)[2 - i] -
                        max_coordinate_landmark[i]) + 1

            self._fixed_im = np.pad(
                self._fixed_im,
                ((pad_before[2], pad_after[2]), (pad_before[1], pad_after[1]),
                 (pad_before[0], pad_after[0])),
                'constant',
                constant_values=(self._setting['defaultPixelValue'], ))
            self._moved_im_affine = np.pad(
                self._moved_im_affine,
                ((pad_before[2], pad_after[2]), (pad_before[1], pad_after[1]),
                 (pad_before[0], pad_after[0])),
                'constant',
                constant_values=(self._setting['defaultPixelValue'], ))
            self._fixed_landmarks_index_padded = self._fixed_landmarks_index + pad_before