def resampling(data, spacing=None, requested_im_list=None):
    """Resample the raw images of a dataset to a given voxel spacing.

    For every image type and every CN listed for ``data`` in the settings,
    each requested image ``'Original' + name + 'Raw'`` is read, resampled to
    ``spacing`` with ``ip.resampler_sitk`` and written to ``'Original' + name``.

    :param data: dataset name (key of ``setting['data']``).
    :param spacing: target voxel spacing, defaults to ``[1, 1, 1]``
        (reported as mm in the progress message).
    :param requested_im_list: names of the images to resample, a subset of
        ``['Im', 'Lung', 'Torso']``; defaults to ``['Im']``. ``'Im'`` is
        interpolated with B-spline, ``'Lung'``/``'Torso'`` with nearest
        neighbour.
    :raises ValueError: if a requested name is not supported. The check is
        done before any setting is loaded or any file is read or written.
    """
    if spacing is None:
        spacing = [1, 1, 1]
    if requested_im_list is None:
        requested_im_list = ['Im']

    # Validate before any I/O so that an unsupported name cannot leave the
    # dataset partially resampled.
    supported_im_list = ['Im', 'Lung', 'Torso']
    for requested_im in requested_im_list:
        if requested_im not in supported_im_list:
            raise ValueError(
                'interpolator is only defined for ["Im", "Lung", "Torso"] not for '
                + str(requested_im))

    data_exp_dict = [{'data': data, 'deform_exp': '3D_max25_D12'}]
    setting = su.initialize_setting('')
    setting = su.load_setting_from_data_dict(setting, data_exp_dict)
    data_setting = setting['data'][data]

    for type_im in range(len(data_setting['types'])):
        for cn in data_setting['CNList']:
            im_info_su = {
                'data': data,
                'type_im': type_im,
                'cn': cn,
                'stage': 1
            }
            for requested_im in requested_im_list:
                # Intensity images use a smooth interpolator; masks must keep
                # their label values, hence nearest neighbour.
                if requested_im == 'Im':
                    interpolator = sitk.sitkBSpline
                else:
                    interpolator = sitk.sitkNearestNeighbor
                im_raw_sitk = sitk.ReadImage(
                    su.address_generator(setting,
                                         'Original' + requested_im + 'Raw',
                                         **im_info_su))
                im_resampled_sitk = ip.resampler_sitk(
                    im_raw_sitk,
                    spacing=spacing,
                    default_pixel_value=data_setting['DefaultPixelValue'],
                    interpolator=interpolator,
                    dimension=3)
                sitk.WriteImage(
                    im_resampled_sitk,
                    su.address_generator(setting, 'Original' + requested_im,
                                         **im_info_su))
                print(data + '_TypeIm' + str(type_im) + '_CN' + str(cn) + '_' +
                      requested_im + ' resampled to ' + str(spacing) + ' mm')
Exemple #2
0
def multi_stage(setting, pair_info, overwrite=False):
    """Register one image pair with a coarse-to-fine (multi-stage) network.

    For every stage in ``setting['ImagePyramidSchedule']`` the fixed and moving
    images (and, if ``setting['UseMask']``, their masks) are downsampled, the
    moving image is warped by the DVF composed from the previous stages, and
    the network of that stage predicts a DVF patch by patch. After the last
    stage the composed DVF is resampled to the stage-1 grid, used to warp the
    stage-1 moving image, and written to disk together with the moved image.

    :param setting: experiment setting dictionary. ``'GPUMemory'`` and
        ``'NumberOfGPU'`` are written into it.
    :param pair_info: information of the pair to be registered;
        ``pair_info[0]`` describes the fixed image, ``pair_info[1]`` the
        moving image.
    :param overwrite: if False and the final moved image already exists,
        return without registering.
    :return: The output moved images and dvf will be written to the disk.
             0: registration is performed correctly
             2: skip overwriting
             3: the dvf is available from the previous experiment [4, 2, 1]. Then just upsample it.
    """
    stage_list = setting['ImagePyramidSchedule']
    if setting['read_pair_mode'] == 'synthetic':
        deformed_im_ext = pair_info[0].get('deformed_im_ext', None)
        im_info_su = {
            'data': pair_info[0]['data'],
            'deform_exp': pair_info[0]['deform_exp'],
            'type_im': pair_info[0]['type_im'],
            'cn': pair_info[0]['cn'],
            'dsmooth': pair_info[0]['dsmooth'],
            'padto': pair_info[0]['padto'],
            'deformed_im_ext': deformed_im_ext
        }
        moved_im_s0_address = su.address_generator(setting,
                                                   'MovedIm_AG',
                                                   stage=1,
                                                   **im_info_su)
        moved_torso_s1_address = su.address_generator(setting,
                                                      'MovedTorso_AG',
                                                      stage=1,
                                                      **im_info_su)
        moved_lung_s1_address = su.address_generator(setting,
                                                     'MovedLung_AG',
                                                     stage=1,
                                                     **im_info_su)
    else:
        moved_im_s0_address = su.address_generator(setting,
                                                   'MovedIm',
                                                   pair_info=pair_info,
                                                   stage=0,
                                                   stage_list=stage_list)
        moved_torso_s1_address = None
        moved_lung_s1_address = None

    # Skip the pair when its final output already exists (in synthetic mode
    # the moved torso has to exist as well), unless overwriting is requested.
    if setting['read_pair_mode'] == 'synthetic':
        if os.path.isfile(moved_im_s0_address) and os.path.isfile(
                moved_torso_s1_address):
            if not overwrite:
                logging.debug('overwrite=False, file ' + moved_im_s0_address +
                              ' already exists, skipping .....')
                return 2
            else:
                logging.debug('overwrite=True, file ' + moved_im_s0_address +
                              ' already exists, but overwriting .....')
    else:
        if os.path.isfile(moved_im_s0_address):
            if not overwrite:
                logging.debug('overwrite=False, file ' + moved_im_s0_address +
                              ' already exists, skipping .....')
                return 2
            else:
                logging.debug('overwrite=True, file ' + moved_im_s0_address +
                              ' already exists, but overwriting .....')

    pair_stage1 = real_pair.Images(
        setting, pair_info, stage=1, padto=setting['PadTo']
        ['stage1'])  # just read the original images without any padding
    pyr = dict()  # pyr: a dictionary of pyramid images
    pyr['fixed_im_s1_sitk'] = pair_stage1.get_fixed_im_sitk()
    pyr['moving_im_s1_sitk'] = pair_stage1.get_moved_im_affine_sitk()
    pyr['fixed_im_s1'] = pair_stage1.get_fixed_im()
    pyr['moving_im_s1'] = pair_stage1.get_moved_im_affine()
    if setting['UseMask']:
        pyr['fixed_mask_s1_sitk'] = pair_stage1.get_fixed_mask_sitk()
        pyr['moving_mask_s1_sitk'] = pair_stage1.get_moved_mask_affine_sitk()
    if setting['read_pair_mode'] == 'real':
        if not (os.path.isdir(
                su.address_generator(setting,
                                     'full_reg_folder',
                                     pair_info=pair_info,
                                     stage_list=stage_list))):
            os.makedirs(
                su.address_generator(setting,
                                     'full_reg_folder',
                                     pair_info=pair_info,
                                     stage_list=stage_list))
    setting['GPUMemory'], setting['NumberOfGPU'] = tfu.client.read_gpu_memory()
    time_before_dvf = time.time()

    # check if DVF is available from the previous experiment [4, 2, 1]. Then just upsample it.
    if stage_list in [[4, 2], [4]]:
        dvf0_address = su.address_generator(setting,
                                            'dvf_s0',
                                            pair_info=pair_info,
                                            stage_list=stage_list)
        chosen_stage = None
        if stage_list == [4, 2]:
            chosen_stage = 2
        elif stage_list == [4]:
            chosen_stage = 4
        if chosen_stage is not None:
            dvf_s_up_address = su.address_generator(setting,
                                                    'dvf_s_up',
                                                    pair_info=pair_info,
                                                    stage=chosen_stage,
                                                    stage_list=[4, 2, 1])
            if os.path.isfile(dvf_s_up_address):
                logging.debug('DVF found from prev exp:' + dvf_s_up_address +
                              ', only performing upsampling')
                dvf_s_up = sitk.ReadImage(dvf_s_up_address)
                # 'dvf_s_up' of a stage is already upsampled by a factor 2,
                # hence the remaining factor chosen_stage / 2.
                dvf0 = ip.resampler_sitk(
                    dvf_s_up,
                    scale=1 / (chosen_stage / 2),
                    im_ref_size=pyr['fixed_im_s1_sitk'].GetSize(),
                    interpolator=sitk.sitkLinear)
                sitk.WriteImage(sitk.Cast(dvf0, sitk.sitkVectorFloat32),
                                dvf0_address)
                return 3

    for i_stage, stage in enumerate(setting['ImagePyramidSchedule']):
        mask_to_zero_stage = setting['network_dict']['stage' +
                                                     str(stage)]['MaskToZero']
        if stage != 1:
            pyr['fixed_im_s' + str(stage) + '_sitk'] = ip.downsampler_gpu(
                pyr['fixed_im_s1_sitk'],
                stage,
                default_pixel_value=setting['data'][
                    pair_info[0]['data']]['DefaultPixelValue'])
            pyr['moving_im_s' + str(stage) + '_sitk'] = ip.downsampler_gpu(
                pyr['moving_im_s1_sitk'],
                stage,
                default_pixel_value=setting['data'][
                    pair_info[1]['data']]['DefaultPixelValue'])
        if setting['UseMask']:
            pyr['fixed_mask_s' + str(stage) + '_sitk'] = ip.resampler_sitk(
                pyr['fixed_mask_s1_sitk'],
                scale=stage,
                im_ref=pyr['fixed_im_s' + str(stage) + '_sitk'],
                default_pixel_value=0,
                interpolator=sitk.sitkNearestNeighbor)
            pyr['moving_mask_s' + str(stage) + '_sitk'] = ip.resampler_sitk(
                pyr['moving_mask_s1_sitk'],
                scale=stage,
                im_ref=pyr['moving_im_s' + str(stage) + '_sitk'],
                default_pixel_value=0,
                interpolator=sitk.sitkNearestNeighbor)

            if setting['WriteMasksForLSTM']:
                # only to be used in sequential training (LSTM)
                if setting['read_pair_mode'] == 'synthetic':
                    fixed_mask_stage_address = su.address_generator(
                        setting,
                        'Deformed' + mask_to_zero_stage,
                        stage=stage,
                        **im_info_su)
                    moving_mask_stage_address = su.address_generator(
                        setting, mask_to_zero_stage, stage=stage, **im_info_su)
                    fixed_im_stage_address = su.address_generator(setting,
                                                                  'DeformedIm',
                                                                  stage=stage,
                                                                  **im_info_su)
                    sitk.WriteImage(
                        sitk.Cast(
                            pyr['fixed_im_s' + str(stage) + '_sitk'],
                            setting['data'][pair_info[1]['data']]
                            ['ImageByte']), fixed_im_stage_address)
                    sitk.WriteImage(pyr['fixed_mask_s' + str(stage) + '_sitk'],
                                    fixed_mask_stage_address)
                    if im_info_su['dsmooth'] != 0 and stage == 4:
                        # not overwirte original images
                        moving_im_stage_address = su.address_generator(
                            setting, 'Im', stage=stage, **im_info_su)
                        sitk.WriteImage(
                            sitk.Cast(
                                pyr['moving_im_s' + str(stage) + '_sitk'],
                                setting['data'][pair_info[1]['data']]
                                ['ImageByte']), moving_im_stage_address)
                        sitk.WriteImage(
                            pyr['moving_mask_s' + str(stage) + '_sitk'],
                            moving_mask_stage_address)
                else:
                    fixed_im_stage_address = su.address_generator(
                        setting, 'Im', stage=stage, **pair_info[0])
                    fixed_mask_stage_address = su.address_generator(
                        setting,
                        mask_to_zero_stage,
                        stage=stage,
                        **pair_info[0])
                    if not os.path.isfile(fixed_im_stage_address):
                        sitk.WriteImage(
                            sitk.Cast(
                                pyr['fixed_im_s' + str(stage) + '_sitk'],
                                setting['data'][pair_info[1]['data']]
                                ['ImageByte']), fixed_im_stage_address)
                    if not os.path.isfile(fixed_mask_stage_address):
                        sitk.WriteImage(
                            pyr['fixed_mask_s' + str(stage) + '_sitk'],
                            fixed_mask_stage_address)
                    if i_stage == 0:
                        moved_im_affine_stage_address = su.address_generator(
                            setting,
                            'MovedImBaseReg',
                            pair_info=pair_info,
                            stage=stage,
                            **pair_info[1])
                        moved_mask_affine_stage_address = su.address_generator(
                            setting,
                            'Moved' + mask_to_zero_stage + 'BaseReg',
                            pair_info=pair_info,
                            stage=stage,
                            **pair_info[1])
                        if not os.path.isfile(moved_im_affine_stage_address):
                            sitk.WriteImage(
                                sitk.Cast(
                                    pyr['moving_im_s' + str(stage) + '_sitk'],
                                    setting['data'][pair_info[1]
                                                    ['data']]['ImageByte']),
                                moved_im_affine_stage_address)
                        if not os.path.isfile(moved_mask_affine_stage_address):
                            sitk.WriteImage(
                                pyr['moving_mask_s' + str(stage) + '_sitk'],
                                moved_mask_affine_stage_address)

        else:
            pyr['fixed_mask_s' + str(stage) + '_sitk'] = None
            pyr['moving_mask_s' + str(stage) + '_sitk'] = None
        input_regnet_moving_mask = None
        if i_stage == 0:
            input_regnet_moving = 'moving_im_s' + str(stage) + '_sitk'
            if setting['UseMask']:
                input_regnet_moving_mask = 'moving_mask_s' + str(
                    stage) + '_sitk'
        else:
            previous_pyramid = setting['ImagePyramidSchedule'][i_stage - 1]
            dvf_composed_previous_up_sitk = 'DVF_s' + str(
                previous_pyramid) + '_composed_up_sitk'
            dvf_composed_previous_sitk = 'DVF_s' + str(
                previous_pyramid) + '_composed_sitk'
            # Compose the DVFs of all previous stages: the second stage only
            # has one, later stages add the previous residual DVF to the
            # already composed and upsampled one.
            if i_stage == 1:
                pyr[dvf_composed_previous_sitk] = pyr[
                    'DVF_s' +
                    str(setting['ImagePyramidSchedule'][i_stage - 1]) +
                    '_sitk']
            elif i_stage > 1:
                pyr[dvf_composed_previous_sitk] = sitk.Add(
                    pyr['DVF_s' +
                        str(setting['ImagePyramidSchedule'][i_stage - 2]) +
                        '_composed_up_sitk'],
                    pyr['DVF_s' +
                        str(setting['ImagePyramidSchedule'][i_stage - 1]) +
                        '_sitk'])
            pyr[dvf_composed_previous_up_sitk] = ip.upsampler_gpu(
                pyr[dvf_composed_previous_sitk],
                round(previous_pyramid / stage),
                output_shape_3d=pyr['fixed_im_s' + str(stage) +
                                    '_sitk'].GetSize()[::-1],
            )
            if setting['WriteAfterEachStage'] and not setting['WriteNoDVF']:
                sitk.WriteImage(
                    sitk.Cast(pyr[dvf_composed_previous_up_sitk],
                              sitk.sitkVectorFloat32),
                    su.address_generator(setting,
                                         'dvf_s_up',
                                         pair_info=pair_info,
                                         stage=previous_pyramid,
                                         stage_list=stage_list))

            dvf_t = sitk.DisplacementFieldTransform(
                pyr[dvf_composed_previous_up_sitk])
            # after this line DVF_composed_previous_up_sitk is converted to a transform. so we need to load it again.
            pyr['moved_im_s' + str(stage) +
                '_sitk'] = ip.resampler_by_transform(
                    pyr['moving_im_s' + str(stage) + '_sitk'],
                    dvf_t,
                    default_pixel_value=setting['data'][
                        pair_info[1]['data']]['DefaultPixelValue'])
            if setting['UseMask']:
                pyr['moved_mask_s' + str(stage) +
                    '_sitk'] = ip.resampler_by_transform(
                        pyr['moving_mask_s' + str(stage) + '_sitk'],
                        dvf_t,
                        default_pixel_value=0,
                        interpolator=sitk.sitkNearestNeighbor)

            pyr[dvf_composed_previous_up_sitk] = dvf_t.GetDisplacementField()
            if setting['WriteAfterEachStage']:
                if setting['read_pair_mode'] == 'synthetic':
                    moved_im_s_address = su.address_generator(setting,
                                                              'MovedIm_AG',
                                                              stage=stage,
                                                              **im_info_su)
                    moved_mask_s_address = su.address_generator(
                        setting,
                        'Moved' + mask_to_zero_stage + '_AG',
                        stage=stage,
                        **im_info_su)
                else:
                    moved_im_s_address = su.address_generator(
                        setting,
                        'MovedIm',
                        pair_info=pair_info,
                        stage=stage,
                        stage_list=stage_list)
                    moved_mask_s_address = su.address_generator(
                        setting,
                        'Moved' + mask_to_zero_stage,
                        pair_info=pair_info,
                        stage=stage,
                        stage_list=stage_list)

                sitk.WriteImage(
                    sitk.Cast(
                        pyr['moved_im_s' + str(stage) + '_sitk'],
                        setting['data'][pair_info[1]['data']]['ImageByte']),
                    moved_im_s_address)

                # The moved mask only exists when masks are used.
                if setting['WriteMasksForLSTM'] and setting['UseMask']:
                    sitk.WriteImage(pyr['moved_mask_s' + str(stage) + '_sitk'],
                                    moved_mask_s_address)

            input_regnet_moving = 'moved_im_s' + str(stage) + '_sitk'
            if setting['UseMask']:
                input_regnet_moving_mask = 'moved_mask_s' + str(
                    stage) + '_sitk'

        pyr['DVF_s' + str(stage)] = np.zeros(
            np.r_[pyr['fixed_im_s' + str(stage) + '_sitk'].GetSize()[::-1], 3],
            dtype=np.float64)
        if setting['network_dict']['stage'+str(stage)]['R'] == 'Auto' and \
                setting['network_dict']['stage'+str(stage)]['Ry'] == 'Auto':
            current_network_name = setting['network_dict'][
                'stage' + str(stage)]['NetworkDesign']
            r_out_erode_default = setting['network_dict'][
                'stage' + str(stage)]['Ry_erode']
            r_in, r_out, r_out_erode = network.utils.find_optimal_radius(
                pyr['fixed_im_s' + str(stage) + '_sitk'],
                current_network_name,
                r_out_erode_default,
                gpu_memory=setting['GPUMemory'],
                number_of_gpu=setting['NumberOfGPU'])

        else:
            r_in = setting['network_dict']['stage' + str(stage)][
                'R']  # Radius of normal resolution patch size. Total size is (2*R +1)
            r_out = setting['network_dict']['stage' + str(stage)][
                'Ry']  # Radius of output. Total size is (2*Ry +1)
            r_out_erode = setting['network_dict']['stage' + str(stage)][
                'Ry_erode']  # at the test time, sometimes there are some problems at the border

        logging.debug(
            'stage' + str(stage) + ' ,' + pair_info[0]['data'] +
            ', CN{}, ImType{}, Size={}'.format(
                pair_info[0]['cn'], pair_info[0]['type_im'], pyr[
                    'fixed_im_s' + str(stage) + '_sitk'].GetSize()[::-1]) +
            ', ' +
            setting['network_dict']['stage' + str(stage)]['NetworkDesign'] +
            ': r_in:{}, r_out:{}, r_out_erode:{}'.format(
                r_in, r_out, r_out_erode))
        # Without UseMask there is no moving-mask entry in pyr (the key is
        # None); pass None, like the fixed mask of this stage.
        if input_regnet_moving_mask is None:
            moved_mask_affine_sitk = None
        else:
            moved_mask_affine_sitk = pyr[input_regnet_moving_mask]
        pair_pyramid = real_pair.Images(
            setting,
            pair_info,
            stage=stage,
            fixed_im_sitk=pyr['fixed_im_s' + str(stage) + '_sitk'],
            moved_im_affine_sitk=pyr[input_regnet_moving],
            fixed_mask_sitk=pyr['fixed_mask_s' + str(stage) + '_sitk'],
            moved_mask_affine_sitk=moved_mask_affine_sitk,
            padto=setting['PadTo']['stage' + str(stage)],
            r_in=r_in,
            r_out=r_out,
            r_out_erode=r_out_erode)

        # building and loading network
        tf.reset_default_graph()
        images_tf = tf.placeholder(
            tf.float32,
            shape=[None, 2 * r_in + 1, 2 * r_in + 1, 2 * r_in + 1, 2],
            name="Images")
        bn_training = tf.placeholder(tf.bool, name='bn_training')
        dvf_tf = getattr(
            getattr(
                network, setting['network_dict']['stage' +
                                                 str(stage)]['NetworkDesign']),
            'network')(images_tf, bn_training)
        logging.debug(' Total number of variables %s' % (np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ])))
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        # The session is closed when the stage is done, so every stage does
        # not leave an open session behind.
        with tf.Session() as sess:
            saver.restore(
                sess,
                su.address_generator(
                    setting,
                    'saved_model_with_step',
                    current_experiment=setting['network_dict'][
                        'stage' + str(stage)]['NetworkLoad'],
                    step=setting['network_dict']['stage' +
                                                 str(stage)]['GlobalStepLoad']))
            while not pair_pyramid.get_sweep_completed():
                # The pyr[DVF_S] is the numpy DVF which will be filled. the dvf_np is an output
                # patch from the network. We control the spatial location of both dvf in this function
                batch_im, win_center, win_r_before, win_r_after, predicted_begin, predicted_end = pair_pyramid.next_sweep_patch(
                )
                time_before_gpu = time.time()
                [dvf_np] = sess.run([dvf_tf],
                                    feed_dict={
                                        images_tf: batch_im,
                                        bn_training: 0
                                    })
                time_after_gpu = time.time()
                logging.debug('GPU: ' + pair_info[0]['data'] +
                              ', CN{} center={} is done in {:.2f}s '.format(
                                  pair_info[0]['cn'], win_center, time_after_gpu -
                                  time_before_gpu))

                pyr['DVF_s'+str(stage)][win_center[0] - win_r_before[0]: win_center[0] + win_r_after[0],
                                        win_center[1] - win_r_before[1]: win_center[1] + win_r_after[1],
                                        win_center[2] - win_r_before[2]: win_center[2] + win_r_after[2], :] = \
                    dvf_np[0, predicted_begin[0]:predicted_end[0], predicted_begin[1]:predicted_end[1], predicted_begin[2]:predicted_end[2], :]

        # rescaling dvf based on the voxel spacing. This must happen exactly
        # once, after all patches are placed: rescaling the whole array
        # inside the sweep would scale earlier patches repeatedly.
        spacing_ref = [1.0 * stage for _ in range(3)]
        spacing_current = pyr['fixed_im_s' + str(stage) +
                              '_sitk'].GetSpacing()
        for dim in range(3):
            pyr['DVF_s' + str(stage)][:, :, :, dim] = pyr['DVF_s' + str(
                stage)][:, :, :,
                        dim] * spacing_current[dim] / spacing_ref[dim]

        pyr['DVF_s' + str(stage) + '_sitk'] = ip.array_to_sitk(
            pyr['DVF_s' + str(stage)],
            im_ref=pyr['fixed_im_s' + str(stage) + '_sitk'],
            is_vector=True)

        if i_stage == (len(setting['ImagePyramidSchedule']) - 1):
            # when all stages are finished, final dvf and moved image are written
            dvf_composed_final_sitk = 'DVF_s' + str(stage) + '_composed_sitk'
            if len(setting['ImagePyramidSchedule']) == 1:
                # need to upsample in the case that last stage is not 1
                if stage == 1:
                    pyr[dvf_composed_final_sitk] = pyr['DVF_s' + str(stage) +
                                                       '_sitk']
                else:
                    pyr[dvf_composed_final_sitk] = ip.resampler_sitk(
                        pyr['DVF_s' + str(stage) + '_sitk'],
                        scale=1 / stage,
                        im_ref_size=pyr['fixed_im_s1_sitk'].GetSize(),
                        interpolator=sitk.sitkLinear)
            else:
                pyr[dvf_composed_final_sitk] = sitk.Add(
                    pyr['DVF_s' + str(setting['ImagePyramidSchedule'][-2]) +
                        '_composed_up_sitk'],
                    pyr['DVF_s' + str(stage) + '_sitk'])
                if stage != 1:
                    pyr[dvf_composed_final_sitk] = ip.resampler_sitk(
                        pyr[dvf_composed_final_sitk],
                        scale=1 / stage,
                        im_ref_size=pyr['fixed_im_s1_sitk'].GetSize(),
                        interpolator=sitk.sitkLinear)
            if not setting['WriteNoDVF']:
                sitk.WriteImage(
                    sitk.Cast(pyr[dvf_composed_final_sitk],
                              sitk.sitkVectorFloat32),
                    su.address_generator(setting,
                                         'dvf_s0',
                                         pair_info=pair_info,
                                         stage_list=stage_list))
            dvf_t = sitk.DisplacementFieldTransform(
                pyr[dvf_composed_final_sitk])
            pyr['moved_im_s0_sitk'] = ip.resampler_by_transform(
                pyr['moving_im_s1_sitk'],
                dvf_t,
                default_pixel_value=setting['data'][
                    pair_info[1]['data']]['DefaultPixelValue'])
            sitk.WriteImage(
                sitk.Cast(pyr['moved_im_s0_sitk'],
                          setting['data'][pair_info[1]['data']]['ImageByte']),
                moved_im_s0_address)

            if setting['WriteMasksForLSTM']:
                mask_to_zero_stage = setting['network_dict'][
                    'stage' + str(stage)]['MaskToZero']
                if setting['read_pair_mode'] == 'synthetic':
                    moving_mask_sitk = sitk.ReadImage(
                        su.address_generator(setting,
                                             mask_to_zero_stage,
                                             stage=1,
                                             **im_info_su))
                    moved_mask_stage1 = ip.resampler_by_transform(
                        moving_mask_sitk,
                        dvf_t,
                        default_pixel_value=0,
                        interpolator=sitk.sitkNearestNeighbor)
                    sitk.WriteImage(
                        moved_mask_stage1,
                        su.address_generator(setting,
                                             'Moved' + mask_to_zero_stage +
                                             '_AG',
                                             stage=1,
                                             **im_info_su))
                    logging.debug('writing ' +
                                  su.address_generator(setting,
                                                       'Moved' +
                                                       mask_to_zero_stage +
                                                       '_AG',
                                                       stage=1,
                                                       **im_info_su))

    time_after_dvf = time.time()
    logging.debug(
        pair_info[0]['data'] + ', CN{}, ImType{} is done in {:.2f}s '.format(
            pair_info[0]['cn'], pair_info[0]['type_im'], time_after_dvf -
            time_before_dvf))

    return 0
Exemple #3
0
def check_downsampled_base_reg(setting,
                               stage,
                               base_reg=None,
                               pair_info=None,
                               mask_to_zero_stage=None):
    """Ensure the base-registered moving image and mask exist at ``stage``.

    Two files are handled, in this order:

    * ``'MovedImBaseReg'``: B-spline interpolation, the moving dataset's
      ``DefaultPixelValue``, written with its ``ImageByte`` pixel type.
    * ``'Moved' + mask_to_zero_stage + 'BaseReg'``: nearest-neighbour
      interpolation, default value 0, written as ``sitk.sitkInt8``.

    A file that already exists at ``stage`` is left untouched. A missing one
    is created from its stage-1 version, either with ``ip.downsampler_gpu``
    (only the image, and only when ``setting['DownSamplingByGPU']`` is set)
    or with ``ip.resampler_sitk``.

    :param setting: experiment setting dictionary; ``'DownSamplingByGPU'``
        is set to False in it when missing.
    :param stage: downsampling factor of the requested pyramid stage.
    :param base_reg: forwarded to ``su.address_generator``.
    :param pair_info: pair description; ``pair_info[1]`` (the moving image
        info) is used. Required in practice despite the None default.
    :param mask_to_zero_stage: mask name inserted into the mask file name
        (presumably ``'Torso'`` or ``'Lung'`` -- confirm). Required in
        practice despite the None default.
    :return: 0
    """
    setting.setdefault('DownSamplingByGPU', False)
    moving_info = pair_info[1]
    moving_data_setting = setting['data'][moving_info['data']]

    # (image name, interpolator, default pixel value, output pixel type)
    image_specs = [
        ('MovedImBaseReg', sitk.sitkBSpline,
         moving_data_setting['DefaultPixelValue'],
         moving_data_setting['ImageByte']),
        ('Moved' + mask_to_zero_stage + 'BaseReg', sitk.sitkNearestNeighbor,
         0, sitk.sitkInt8),
    ]

    def stage_address(image_name, at_stage):
        # Every file here is addressed with the same pair/base_reg/moving info.
        return su.address_generator(setting,
                                    image_name,
                                    stage=at_stage,
                                    pair_info=pair_info,
                                    base_reg=base_reg,
                                    **moving_info)

    for image_name, interpolator, default_value, pixel_type in image_specs:
        target_address = stage_address(image_name, stage)
        if os.path.isfile(target_address):
            continue  # never overwrite an existing downsampled file

        source_sitk = sitk.ReadImage(stage_address(image_name, 1))
        use_gpu = setting['DownSamplingByGPU'] and image_name == 'MovedImBaseReg'
        if use_gpu:
            downsampled = ip.downsampler_gpu(
                sitk.GetArrayFromImage(source_sitk),
                stage,
                normalize_kernel=True,
                default_pixel_value=default_value)
            downsampled_sitk = ip.array_to_sitk(
                downsampled,
                origin=source_sitk.GetOrigin(),
                spacing=tuple(s * stage for s in source_sitk.GetSpacing()),
                direction=source_sitk.GetDirection())
        else:
            # Masks are resampled onto the grid of the already downsampled
            # image of the same stage (written earlier in this loop).
            reference_sitk = None
            if image_name in ('MovedTorsoBaseReg', 'MovedLungBaseReg'):
                reference_sitk = sitk.ReadImage(
                    stage_address('MovedImBaseReg', stage))
            downsampled_sitk = ip.resampler_sitk(
                source_sitk,
                scale=stage,
                im_ref=reference_sitk,
                default_pixel_value=default_value,
                interpolator=interpolator)

        sitk.WriteImage(sitk.Cast(downsampled_sitk, pixel_type),
                        target_address)
    return 0
Exemple #4
0
def check_downsampled_images(setting, im_info, stage):
    """Create any missing stage-``stage`` images for one image/deformation set.

    For every entry of ``im_list_downsample`` (Im, Lung, Torso, DeformedIm,
    DeformedLung, DeformedTorso) the stage-``stage`` file address is built with
    ``su.address_generator``. If that file does not exist yet, the stage-1
    image is read, resampled with ``ip.resampler_sitk(scale=stage, ...)``,
    optionally constant-padded up to ``padto``, cast to the entry's
    ``ImageByte`` type and written to the stage address.

    :param setting: settings dict. Reads ``setting['data'][data]`` keys
        ``DefaultPixelValue`` and ``ImageByte``, plus the top-level keys
        ``DownSamplingByGPU`` and ``Dim``.
    :param im_info: dict with keys ``data``, ``deform_exp``, ``type_im``,
        ``cn``, ``dsmooth``, ``padto`` and ``deformed_im_ext``.
    :param stage: downsampling factor, passed as ``scale`` to the resampler.
    :return: 0
    :raises ValueError: if an entry has an unknown interpolator name, or if
        a resampled image is larger than ``padto`` in some dimension.
    """
    padto = im_info['padto']
    im_info_su = {
        'data': im_info['data'],
        'deform_exp': im_info['deform_exp'],
        'type_im': im_info['type_im'],
        'cn': im_info['cn'],
        'dsmooth': im_info['dsmooth']
    }
    data_setting = setting['data'][im_info['data']]

    im_list_downsample = [{
        'Image': 'Im',
        'interpolator': 'BSpline',
        'DefaultPixelValue': data_setting['DefaultPixelValue'],
        'ImageByte': data_setting['ImageByte']
    }, {
        'Image': 'Lung',
        'interpolator': 'NearestNeighbor',
        'DefaultPixelValue': 0,
        'ImageByte': sitk.sitkInt8
    }, {
        'Image': 'Torso',
        'interpolator': 'NearestNeighbor',
        'DefaultPixelValue': 0,
        'ImageByte': sitk.sitkInt8
    }, {
        'Image': 'DeformedIm',
        'interpolator': 'BSpline',
        'DefaultPixelValue': data_setting['DefaultPixelValue'],
        'ImageByte': data_setting['ImageByte']
    }, {
        'Image': 'DeformedLung',
        'interpolator': 'NearestNeighbor',
        'DefaultPixelValue': 0,
        'ImageByte': sitk.sitkInt8
    }, {
        'Image': 'DeformedTorso',
        'interpolator': 'NearestNeighbor',
        'DefaultPixelValue': 0,
        'ImageByte': sitk.sitkInt8
    }]

    for im_dict in im_list_downsample:
        im_stage_address = su.address_generator(
            setting,
            im_dict['Image'],
            stage=stage,
            padto=padto,
            deformed_im_ext=im_info['deformed_im_ext'],
            **im_info_su)
        if os.path.isfile(im_stage_address):
            # Already generated earlier; nothing to do for this image.
            continue

        im_s1_sitk = sitk.ReadImage(
            su.address_generator(
                setting,
                im_dict['Image'],
                stage=1,
                padto=padto,
                deformed_im_ext=im_info['deformed_im_ext'],
                **im_info_su))

        # NOTE(review): no entry of im_list_downsample is named 'OriginalIm',
        # so this GPU branch is currently never taken. Note also that this
        # branch skips the padto padding below. Confirm the intended name
        # before enabling it.
        if setting['DownSamplingByGPU'] and im_dict['Image'] == 'OriginalIm':
            im_s1 = sitk.GetArrayFromImage(im_s1_sitk)
            im_stage = ip.downsampler_gpu(
                im_s1,
                stage,
                normalize_kernel=True,
                default_pixel_value=im_dict['DefaultPixelValue'])

            im_stage_sitk = ip.array_to_sitk(
                im_stage,
                origin=im_s1_sitk.GetOrigin(),
                spacing=tuple(i * stage for i in im_s1_sitk.GetSpacing()),
                direction=im_s1_sitk.GetDirection())
        else:
            if im_dict['interpolator'] == 'NearestNeighbor':
                interpolator = sitk.sitkNearestNeighbor
            elif im_dict['interpolator'] == 'BSpline':
                interpolator = sitk.sitkBSpline
            else:
                raise ValueError(
                    "interpolator should be in ['NearestNeighbor', 'BSpline']"
                )

            # Masks are resampled onto the grid of the stage-level 'Im' image.
            if im_dict['Image'] in ['Torso', 'Lung']:
                im_ref_sitk = sitk.ReadImage(
                    su.address_generator(setting,
                                         'Im',
                                         stage=stage,
                                         padto=padto,
                                         **im_info_su))
            else:
                im_ref_sitk = None
            im_stage_sitk = ip.resampler_sitk(
                im_s1_sitk,
                scale=stage,
                im_ref=im_ref_sitk,
                default_pixel_value=im_dict['DefaultPixelValue'],
                interpolator=interpolator)

            if padto is not None:
                dim_im = setting['Dim']
                im_size = np.array(im_stage_sitk.GetSize())
                # padto may be a scalar or a per-dimension sequence; numpy
                # broadcasting handles both.
                extra_to_pad = (np.asarray(padto) - im_size)[:dim_im]
                if np.any(extra_to_pad < 0):
                    raise ValueError(
                        'size of the padto=' + str(padto) +
                        ' should not be smaller than the size of the image {}'
                        .format(im_size))
                # Split the extra voxels evenly. An odd remainder goes to the
                # "after" side.
                pad_before = [int(e) // 2 for e in extra_to_pad]
                pad_after = [int(e) - b for e, b in zip(extra_to_pad, pad_before)]
                im_stage_sitk = sitk.ConstantPad(
                    im_stage_sitk,
                    pad_before,
                    pad_after,
                    constant=im_dict['DefaultPixelValue'],
                )

        sitk.WriteImage(sitk.Cast(im_stage_sitk, im_dict['ImageByte']),
                        im_stage_address)
    return 0