def landmarks_from_dvf_old(setting, IN_test_list): # %%------------------------------------------- Setting of generating synthetic DVFs------------------------------------------ saved_file = su.address_generator(setting, 'landmarks_file') if os.path.isfile(saved_file): raise ValueError('cannot overwrite, please change the name of the pickle file: ' + saved_file) # %%------------------------------------------------------ running --------------------------------------------------------- landmarks = [{} for _ in range(22)] for IN in IN_test_list: image_pair = real_pair.Images(IN, setting=setting) image_pair.prepare_for_landmarks(padding=False) landmarks[IN]['FixedLandmarksWorld'] = image_pair._fixed_landmarks_world.copy() landmarks[IN]['MovingLandmarksWorld'] = image_pair._moving_landmarks_world.copy() landmarks[IN]['FixedAfterAffineLandmarksWorld'] = image_pair._fixed_after_affine_landmarks_world.copy() landmarks[IN]['DVFAffine'] = image_pair._dvf_affine.copy() landmarks[IN]['DVF_nonrigidGroundTruth'] = image_pair._moving_landmarks_world - image_pair._fixed_after_affine_landmarks_world landmarks[IN]['FixedLandmarksIndex'] = image_pair._fixed_landmarks_index dvf_s0_sitk = sitk.ReadImage(su.address_generator(setting, 'dvf_s0', cn=IN)) dvf_s0 = sitk.GetArrayFromImage(dvf_s0_sitk) landmarks[IN]['DVFRegNet'] = np.stack([dvf_s0[image_pair._fixed_landmarks_index[i, 2], image_pair._fixed_landmarks_index[i, 1], image_pair._fixed_landmarks_index[i, 0]] for i in range(len(image_pair._fixed_landmarks_index))]) print('CPU: IN = {} is done in '.format(IN)) with open(saved_file, 'wb') as f: pickle.dump(landmarks, f)
def landmark_info(setting, pair_info, base_reg='Affine'): """ extract landmark information. Be very careful about the order: :param setting: :param pair_info: :return: 'FixedLandmarksWorld': xyz order 'MovingLandmarksWorld': xyz order 'FixedAfterAffineLandmarksWorld': xyz order 'FixedLandmarksIndex': xyz order """ pair = real_pair.Images(setting, pair_info, stage=1) pair.prepare_for_landmarks(padding=False) current_landmark = { 'setting': setting, 'pair_info': pair_info, 'FixedLandmarksWorld': pair._fixed_landmarks_world.copy(), 'MovingLandmarksWorld': pair._moving_landmarks_world.copy(), 'FixedAfter' + base_reg + 'LandmarksWorld': pair._fixed_after_affine_landmarks_world.copy(), 'DVF' + base_reg: pair._dvf_affine.copy(), 'DVF_nonrigidGroundTruth': pair._moving_landmarks_world - pair._fixed_after_affine_landmarks_world, 'FixedLandmarksIndex': pair._fixed_landmarks_index.copy(), } return current_landmark
def landmarks_from_dvf(setting, pair_info): stage_list = setting['ImagePyramidSchedule'] pair = real_pair.Images(setting, pair_info, stage=1) pair.prepare_for_landmarks(padding=False) dvf_s0 = sitk.GetArrayFromImage(sitk.ReadImage( su.address_generator(setting, 'dvf_s0', pair_info=pair_info, stage_list=stage_list))) current_landmark = {'setting': setting, 'pair_info': pair_info, 'FixedLandmarksWorld': pair._fixed_landmarks_world.copy(), 'MovingLandmarksWorld': pair._moving_landmarks_world.copy(), 'FixedAfterAffineLandmarksWorld': pair._fixed_after_affine_landmarks_world.copy(), 'DVFAffine': pair._dvf_affine.copy(), 'DVF_nonrigidGroundTruth': pair._moving_landmarks_world - pair._fixed_after_affine_landmarks_world, 'FixedLandmarksIndex': pair._fixed_landmarks_index.copy(), 'DVFRegNet': np.stack([dvf_s0[pair._fixed_landmarks_index[i, 2], pair._fixed_landmarks_index[i, 1], pair._fixed_landmarks_index[i, 0]] for i in range(len(pair._fixed_landmarks_index))]) } return current_landmark
def multi_stage(setting, pair_info, overwrite=False): """ :param setting: :param pair_info: information of the pair to be registered. :param overwrite: :return: The output moved images and dvf will be written to the disk. 1: registration is performed correctly 2: skip overwriting 3: the dvf is available from the previous experiment [4, 2, 1]. Then just upsample it. """ stage_list = setting['ImagePyramidSchedule'] if setting['read_pair_mode'] == 'synthetic': deformed_im_ext = pair_info[0].get('deformed_im_ext', None) im_info_su = { 'data': pair_info[0]['data'], 'deform_exp': pair_info[0]['deform_exp'], 'type_im': pair_info[0]['type_im'], 'cn': pair_info[0]['cn'], 'dsmooth': pair_info[0]['dsmooth'], 'padto': pair_info[0]['padto'], 'deformed_im_ext': deformed_im_ext } moved_im_s0_address = su.address_generator(setting, 'MovedIm_AG', stage=1, **im_info_su) moved_torso_s1_address = su.address_generator(setting, 'MovedTorso_AG', stage=1, **im_info_su) moved_lung_s1_address = su.address_generator(setting, 'MovedLung_AG', stage=1, **im_info_su) else: moved_im_s0_address = su.address_generator(setting, 'MovedIm', pair_info=pair_info, stage=0, stage_list=stage_list) moved_torso_s1_address = None moved_lung_s1_address = None if setting['read_pair_mode'] == 'synthetic': if os.path.isfile(moved_im_s0_address) and os.path.isfile( moved_torso_s1_address): if not overwrite: logging.debug('overwrite=False, file ' + moved_im_s0_address + ' already exists, skipping .....') return 2 else: logging.debug('overwrite=True, file ' + moved_im_s0_address + ' already exists, but overwriting .....') else: if os.path.isfile(moved_im_s0_address): if not overwrite: logging.debug('overwrite=False, file ' + moved_im_s0_address + ' already exists, skipping .....') return 2 else: logging.debug('overwrite=True, file ' + moved_im_s0_address + ' already exists, but overwriting .....') pair_stage1 = real_pair.Images( setting, pair_info, stage=1, padto=setting['PadTo'] ['stage1']) # just read the original images without any padding pyr = dict() # pyr: a dictionary of pyramid images pyr['fixed_im_s1_sitk'] = pair_stage1.get_fixed_im_sitk() pyr['moving_im_s1_sitk'] = pair_stage1.get_moved_im_affine_sitk() pyr['fixed_im_s1'] = pair_stage1.get_fixed_im() pyr['moving_im_s1'] = pair_stage1.get_moved_im_affine() if setting['UseMask']: pyr['fixed_mask_s1_sitk'] = pair_stage1.get_fixed_mask_sitk() pyr['moving_mask_s1_sitk'] = pair_stage1.get_moved_mask_affine_sitk() if setting['read_pair_mode'] == 'real': if not (os.path.isdir( su.address_generator(setting, 'full_reg_folder', pair_info=pair_info, stage_list=stage_list))): os.makedirs( su.address_generator(setting, 'full_reg_folder', pair_info=pair_info, stage_list=stage_list)) setting['GPUMemory'], setting['NumberOfGPU'] = tfu.client.read_gpu_memory() time_before_dvf = time.time() # check if DVF is available from the previous experiment [4, 2, 1]. Then just upsample it. if stage_list in [[4, 2], [4]]: dvf0_address = su.address_generator(setting, 'dvf_s0', pair_info=pair_info, stage_list=stage_list) chosen_stage = None if stage_list == [4, 2]: chosen_stage = 2 elif stage_list == [4]: chosen_stage = 4 if chosen_stage is not None: dvf_s_up_address = su.address_generator(setting, 'dvf_s_up', pair_info=pair_info, stage=chosen_stage, stage_list=[4, 2, 1]) if os.path.isfile(dvf_s_up_address): logging.debug('DVF found from prev exp:' + dvf_s_up_address + ', only performing upsampling') dvf_s_up = sitk.ReadImage(dvf_s_up_address) dvf0 = ip.resampler_sitk( dvf_s_up, scale=1 / (chosen_stage / 2), im_ref_size=pyr['fixed_im_s1_sitk'].GetSize(), interpolator=sitk.sitkLinear) sitk.WriteImage(sitk.Cast(dvf0, sitk.sitkVectorFloat32), dvf0_address) return 3 for i_stage, stage in enumerate(setting['ImagePyramidSchedule']): mask_to_zero_stage = setting['network_dict']['stage' + str(stage)]['MaskToZero'] if stage != 1: pyr['fixed_im_s' + str(stage) + '_sitk'] = ip.downsampler_gpu( pyr['fixed_im_s1_sitk'], stage, default_pixel_value=setting['data'][ pair_info[0]['data']]['DefaultPixelValue']) pyr['moving_im_s' + str(stage) + '_sitk'] = ip.downsampler_gpu( pyr['moving_im_s1_sitk'], stage, default_pixel_value=setting['data'][ pair_info[1]['data']]['DefaultPixelValue']) if setting['UseMask']: pyr['fixed_mask_s' + str(stage) + '_sitk'] = ip.resampler_sitk( pyr['fixed_mask_s1_sitk'], scale=stage, im_ref=pyr['fixed_im_s' + str(stage) + '_sitk'], default_pixel_value=0, interpolator=sitk.sitkNearestNeighbor) pyr['moving_mask_s' + str(stage) + '_sitk'] = ip.resampler_sitk( pyr['moving_mask_s1_sitk'], scale=stage, im_ref=pyr['moving_im_s' + str(stage) + '_sitk'], default_pixel_value=0, interpolator=sitk.sitkNearestNeighbor) if setting['WriteMasksForLSTM']: # only to be used in sequential training (LSTM) if setting['read_pair_mode'] == 'synthetic': fixed_mask_stage_address = su.address_generator( setting, 'Deformed' + mask_to_zero_stage, stage=stage, **im_info_su) moving_mask_stage_address = su.address_generator( setting, mask_to_zero_stage, stage=stage, **im_info_su) fixed_im_stage_address = su.address_generator(setting, 'DeformedIm', stage=stage, **im_info_su) sitk.WriteImage( sitk.Cast( pyr['fixed_im_s' + str(stage) + '_sitk'], setting['data'][pair_info[1]['data']] ['ImageByte']), fixed_im_stage_address) sitk.WriteImage(pyr['fixed_mask_s' + str(stage) + '_sitk'], fixed_mask_stage_address) if im_info_su['dsmooth'] != 0 and stage == 4: # not overwirte original images moving_im_stage_address = su.address_generator( setting, 'Im', stage=stage, **im_info_su) sitk.WriteImage( sitk.Cast( pyr['moving_im_s' + str(stage) + '_sitk'], setting['data'][pair_info[1]['data']] ['ImageByte']), moving_im_stage_address) sitk.WriteImage( pyr['moving_mask_s' + str(stage) + '_sitk'], moving_mask_stage_address) else: fixed_im_stage_address = su.address_generator( setting, 'Im', stage=stage, **pair_info[0]) fixed_mask_stage_address = su.address_generator( setting, mask_to_zero_stage, stage=stage, **pair_info[0]) if not os.path.isfile(fixed_im_stage_address): sitk.WriteImage( sitk.Cast( pyr['fixed_im_s' + str(stage) + '_sitk'], setting['data'][pair_info[1]['data']] ['ImageByte']), fixed_im_stage_address) if not os.path.isfile(fixed_mask_stage_address): sitk.WriteImage( pyr['fixed_mask_s' + str(stage) + '_sitk'], fixed_mask_stage_address) if i_stage == 0: moved_im_affine_stage_address = su.address_generator( setting, 'MovedImBaseReg', pair_info=pair_info, stage=stage, **pair_info[1]) moved_mask_affine_stage_address = su.address_generator( setting, 'Moved' + mask_to_zero_stage + 'BaseReg', pair_info=pair_info, stage=stage, **pair_info[1]) if not os.path.isfile(moved_im_affine_stage_address): sitk.WriteImage( sitk.Cast( pyr['moving_im_s' + str(stage) + '_sitk'], setting['data'][pair_info[1] ['data']]['ImageByte']), moved_im_affine_stage_address) if not os.path.isfile(moved_mask_affine_stage_address): sitk.WriteImage( pyr['moving_mask_s' + str(stage) + '_sitk'], moved_mask_affine_stage_address) else: pyr['fixed_mask_s' + str(stage) + '_sitk'] = None pyr['moving_mask_s' + str(stage) + '_sitk'] = None input_regnet_moving_mask = None if i_stage == 0: input_regnet_moving = 'moving_im_s' + str(stage) + '_sitk' if setting['UseMask']: input_regnet_moving_mask = 'moving_mask_s' + str( stage) + '_sitk' else: previous_pyramid = setting['ImagePyramidSchedule'][i_stage - 1] dvf_composed_previous_up_sitk = 'DVF_s' + str( previous_pyramid) + '_composed_up_sitk' dvf_composed_previous_sitk = 'DVF_s' + str( previous_pyramid) + '_composed_sitk' if i_stage == 1: pyr[dvf_composed_previous_sitk] = pyr[ 'DVF_s' + str(setting['ImagePyramidSchedule'][i_stage - 1]) + '_sitk'] elif i_stage > 1: pyr[dvf_composed_previous_sitk] = sitk.Add( pyr['DVF_s' + str(setting['ImagePyramidSchedule'][i_stage - 2]) + '_composed_up_sitk'], pyr['DVF_s' + str(setting['ImagePyramidSchedule'][i_stage - 1]) + '_sitk']) pyr[dvf_composed_previous_up_sitk] = ip.upsampler_gpu( pyr[dvf_composed_previous_sitk], round(previous_pyramid / stage), output_shape_3d=pyr['fixed_im_s' + str(stage) + '_sitk'].GetSize()[::-1], ) if setting['WriteAfterEachStage'] and not setting['WriteNoDVF']: sitk.WriteImage( sitk.Cast(pyr[dvf_composed_previous_up_sitk], sitk.sitkVectorFloat32), su.address_generator(setting, 'dvf_s_up', pair_info=pair_info, stage=previous_pyramid, stage_list=stage_list)) dvf_t = sitk.DisplacementFieldTransform( pyr[dvf_composed_previous_up_sitk]) # after this line DVF_composed_previous_up_sitk is converted to a transform. so we need to load it again. pyr['moved_im_s' + str(stage) + '_sitk'] = ip.resampler_by_transform( pyr['moving_im_s' + str(stage) + '_sitk'], dvf_t, default_pixel_value=setting['data'][ pair_info[1]['data']]['DefaultPixelValue']) if setting['UseMask']: pyr['moved_mask_s' + str(stage) + '_sitk'] = ip.resampler_by_transform( pyr['moving_mask_s' + str(stage) + '_sitk'], dvf_t, default_pixel_value=0, interpolator=sitk.sitkNearestNeighbor) pyr[dvf_composed_previous_up_sitk] = dvf_t.GetDisplacementField() if setting['WriteAfterEachStage']: if setting['read_pair_mode'] == 'synthetic': moved_im_s_address = su.address_generator(setting, 'MovedIm_AG', stage=stage, **im_info_su) moved_mask_s_address = su.address_generator( setting, 'Moved' + mask_to_zero_stage + '_AG', stage=stage, **im_info_su) else: moved_im_s_address = su.address_generator( setting, 'MovedIm', pair_info=pair_info, stage=stage, stage_list=stage_list) moved_mask_s_address = su.address_generator( setting, 'Moved' + mask_to_zero_stage, pair_info=pair_info, stage=stage, stage_list=stage_list) sitk.WriteImage( sitk.Cast( pyr['moved_im_s' + str(stage) + '_sitk'], setting['data'][pair_info[1]['data']]['ImageByte']), moved_im_s_address) if setting['WriteMasksForLSTM']: sitk.WriteImage(pyr['moved_mask_s' + str(stage) + '_sitk'], moved_mask_s_address) input_regnet_moving = 'moved_im_s' + str(stage) + '_sitk' if setting['UseMask']: input_regnet_moving_mask = 'moved_mask_s' + str( stage) + '_sitk' pyr['DVF_s' + str(stage)] = np.zeros( np.r_[pyr['fixed_im_s' + str(stage) + '_sitk'].GetSize()[::-1], 3], dtype=np.float64) if setting['network_dict']['stage'+str(stage)]['R'] == 'Auto' and \ setting['network_dict']['stage'+str(stage)]['Ry'] == 'Auto': current_network_name = setting['network_dict'][ 'stage' + str(stage)]['NetworkDesign'] r_out_erode_default = setting['network_dict'][ 'stage' + str(stage)]['Ry_erode'] r_in, r_out, r_out_erode = network.utils.find_optimal_radius( pyr['fixed_im_s' + str(stage) + '_sitk'], current_network_name, r_out_erode_default, gpu_memory=setting['GPUMemory'], number_of_gpu=setting['NumberOfGPU']) else: r_in = setting['network_dict']['stage' + str(stage)][ 'R'] # Radius of normal resolution patch size. Total size is (2*R +1) r_out = setting['network_dict']['stage' + str(stage)][ 'Ry'] # Radius of output. Total size is (2*Ry +1) r_out_erode = setting['network_dict']['stage' + str(stage)][ 'Ry_erode'] # at the test time, sometimes there are some problems at the border logging.debug( 'stage' + str(stage) + ' ,' + pair_info[0]['data'] + ', CN{}, ImType{}, Size={}'.format( pair_info[0]['cn'], pair_info[0]['type_im'], pyr[ 'fixed_im_s' + str(stage) + '_sitk'].GetSize()[::-1]) + ', ' + setting['network_dict']['stage' + str(stage)]['NetworkDesign'] + ': r_in:{}, r_out:{}, r_out_erode:{}'.format( r_in, r_out, r_out_erode)) pair_pyramid = real_pair.Images( setting, pair_info, stage=stage, fixed_im_sitk=pyr['fixed_im_s' + str(stage) + '_sitk'], moved_im_affine_sitk=pyr[input_regnet_moving], fixed_mask_sitk=pyr['fixed_mask_s' + str(stage) + '_sitk'], moved_mask_affine_sitk=pyr[input_regnet_moving_mask], padto=setting['PadTo']['stage' + str(stage)], r_in=r_in, r_out=r_out, r_out_erode=r_out_erode) # building and loading network tf.reset_default_graph() images_tf = tf.placeholder( tf.float32, shape=[None, 2 * r_in + 1, 2 * r_in + 1, 2 * r_in + 1, 2], name="Images") bn_training = tf.placeholder(tf.bool, name='bn_training') dvf_tf = getattr( getattr( network, setting['network_dict']['stage' + str(stage)]['NetworkDesign']), 'network')(images_tf, bn_training) logging.debug(' Total number of variables %s' % (np.sum([ np.prod(v.get_shape().as_list()) for v in tf.trainable_variables() ]))) saver = tf.train.Saver(tf.global_variables(), max_to_keep=1) sess = tf.Session() saver.restore( sess, su.address_generator( setting, 'saved_model_with_step', current_experiment=setting['network_dict'][ 'stage' + str(stage)]['NetworkLoad'], step=setting['network_dict']['stage' + str(stage)]['GlobalStepLoad'])) while not pair_pyramid.get_sweep_completed(): # The pyr[DVF_S] is the numpy DVF which will be filled. the dvf_np is an output # patch from the network. We control the spatial location of both dvf in this function batch_im, win_center, win_r_before, win_r_after, predicted_begin, predicted_end = pair_pyramid.next_sweep_patch( ) time_before_gpu = time.time() [dvf_np] = sess.run([dvf_tf], feed_dict={ images_tf: batch_im, bn_training: 0 }) time_after_gpu = time.time() logging.debug('GPU: ' + pair_info[0]['data'] + ', CN{} center={} is done in {:.2f}s '.format( pair_info[0]['cn'], win_center, time_after_gpu - time_before_gpu)) pyr['DVF_s'+str(stage)][win_center[0] - win_r_before[0]: win_center[0] + win_r_after[0], win_center[1] - win_r_before[1]: win_center[1] + win_r_after[1], win_center[2] - win_r_before[2]: win_center[2] + win_r_after[2], :] = \ dvf_np[0, predicted_begin[0]:predicted_end[0], predicted_begin[1]:predicted_end[1], predicted_begin[2]:predicted_end[2], :] # rescaling dvf based on the voxel spacing: spacing_ref = [1.0 * stage for _ in range(3)] spacing_current = pyr['fixed_im_s' + str(stage) + '_sitk'].GetSpacing() for dim in range(3): pyr['DVF_s' + str(stage)][:, :, :, dim] = pyr['DVF_s' + str( stage)][:, :, :, dim] * spacing_current[dim] / spacing_ref[dim] pyr['DVF_s' + str(stage) + '_sitk'] = ip.array_to_sitk( pyr['DVF_s' + str(stage)], im_ref=pyr['fixed_im_s' + str(stage) + '_sitk'], is_vector=True) if i_stage == (len(setting['ImagePyramidSchedule']) - 1): # when all stages are finished, final dvf and moved image are written dvf_composed_final_sitk = 'DVF_s' + str(stage) + '_composed_sitk' if len(setting['ImagePyramidSchedule']) == 1: # need to upsample in the case that last stage is not 1 if stage == 1: pyr[dvf_composed_final_sitk] = pyr['DVF_s' + str(stage) + '_sitk'] else: pyr[dvf_composed_final_sitk] = ip.resampler_sitk( pyr['DVF_s' + str(stage) + '_sitk'], scale=1 / stage, im_ref_size=pyr['fixed_im_s1_sitk'].GetSize(), interpolator=sitk.sitkLinear) else: pyr[dvf_composed_final_sitk] = sitk.Add( pyr['DVF_s' + str(setting['ImagePyramidSchedule'][-2]) + '_composed_up_sitk'], pyr['DVF_s' + str(stage) + '_sitk']) if stage != 1: pyr[dvf_composed_final_sitk] = ip.resampler_sitk( pyr[dvf_composed_final_sitk], scale=1 / stage, im_ref_size=pyr['fixed_im_s1_sitk'].GetSize(), interpolator=sitk.sitkLinear) if not setting['WriteNoDVF']: sitk.WriteImage( sitk.Cast(pyr[dvf_composed_final_sitk], sitk.sitkVectorFloat32), su.address_generator(setting, 'dvf_s0', pair_info=pair_info, stage_list=stage_list)) dvf_t = sitk.DisplacementFieldTransform( pyr[dvf_composed_final_sitk]) pyr['moved_im_s0_sitk'] = ip.resampler_by_transform( pyr['moving_im_s1_sitk'], dvf_t, default_pixel_value=setting['data'][ pair_info[1]['data']]['DefaultPixelValue']) sitk.WriteImage( sitk.Cast(pyr['moved_im_s0_sitk'], setting['data'][pair_info[1]['data']]['ImageByte']), moved_im_s0_address) if setting['WriteMasksForLSTM']: mask_to_zero_stage = setting['network_dict'][ 'stage' + str(stage)]['MaskToZero'] if setting['read_pair_mode'] == 'synthetic': moving_mask_sitk = sitk.ReadImage( su.address_generator(setting, mask_to_zero_stage, stage=1, **im_info_su)) moved_mask_stage1 = ip.resampler_by_transform( moving_mask_sitk, dvf_t, default_pixel_value=0, interpolator=sitk.sitkNearestNeighbor) sitk.WriteImage( moved_mask_stage1, su.address_generator(setting, 'Moved' + mask_to_zero_stage + '_AG', stage=1, **im_info_su)) logging.debug('writing ' + su.address_generator(setting, 'Moved' + mask_to_zero_stage + '_AG', stage=1, **im_info_su)) time_after_dvf = time.time() logging.debug( pair_info[0]['data'] + ', CN{}, ImType{} is done in {:.2f}s '.format( pair_info[0]['cn'], pair_info[0]['type_im'], time_after_dvf - time_before_dvf)) return 0
def multi_stage(setting, network_dict, pair_info, overwrite=False): """ :param setting: :param network_dict: :param pair_info: information of the pair to be registered. :param overwrite: :return: The output moved images and dvf will be written to the disk. 1: registration is performed correctly 2: skip overwriting """ stage_list = setting['ImagePyramidSchedule'] final_moved_image_address = su.address_generator(setting, 'moved_image', pair_info=pair_info, stage=0, stage_list=stage_list) if os.path.isfile(final_moved_image_address): if not overwrite: print('overwrite=False, file '+final_moved_image_address+' already exists, skipping .....') return 2 else: print('overwrite=True, file '+final_moved_image_address+' already exists, but overwriting .....') pair_stage1 = real_pair.Images(setting, pair_info, stage=1) pyr = dict() # pyr: a dictionary of pyramid images pyr['fixed_im_s1_sitk'] = pair_stage1.get_fixed_im_sitk() pyr['moving_im_s1_sitk'] = pair_stage1.get_moved_im_affine_sitk() pyr['fixed_im_s1'] = pair_stage1.get_fixed_im() pyr['moving_im_s1'] = pair_stage1.get_moved_im_affine() if setting['torsoMask']: pyr['fixed_torso_s1_sitk'] = pair_stage1.get_fixed_torso_sitk() pyr['moving_torso_s1_sitk'] = pair_stage1.get_moved_torso_affine_sitk() if not (os.path.isdir(su.address_generator(setting, 'full_reg_folder', pair_info=pair_info, stage_list=stage_list))): os.makedirs(su.address_generator(setting, 'full_reg_folder', pair_info=pair_info, stage_list=stage_list)) time_before_dvf = time.time() for i_stage, stage in enumerate(setting['ImagePyramidSchedule']): if stage != 1: pyr['fixed_im_s'+str(stage)+'_sitk'] = ip.downsampler_gpu(pyr['fixed_im_s1_sitk'], stage, default_pixel_value=setting['data'][pair_info[0]['data']]['defaultPixelValue']) pyr['moving_im_s'+str(stage)+'_sitk'] = ip.downsampler_gpu(pyr['moving_im_s1_sitk'], stage, default_pixel_value=setting['data'][pair_info[1]['data']]['defaultPixelValue']) if setting['torsoMask']: pyr['fixed_torso_s'+str(stage)+'_sitk'] = ip.downsampler_sitk(pyr['fixed_torso_s1_sitk'], stage, im_ref=pyr['fixed_im_s' + str(stage) + '_sitk'], default_pixel_value=0, interpolator=sitk.sitkNearestNeighbor) pyr['moving_torso_s'+str(stage)+'_sitk'] = ip.downsampler_sitk(pyr['moving_torso_s1_sitk'], stage, im_ref=pyr['moving_im_s' + str(stage) + '_sitk'], default_pixel_value=0, interpolator=sitk.sitkNearestNeighbor) else: pyr['fixed_torso_s'+str(stage)+'_sitk'] = None pyr['moving_torso_s'+str(stage)+'_sitk'] = None input_regnet_moving_torso = None if i_stage == 0: input_regnet_moving = 'moving_im_s'+str(stage)+'_sitk' if setting['torsoMask']: input_regnet_moving_torso = 'moving_torso_s'+str(stage)+'_sitk' else: previous_pyramid = setting['ImagePyramidSchedule'][i_stage - 1] dvf_composed_previous_up_sitk = 'DVF_s'+str(previous_pyramid)+'_composed_up_sitk' dvf_composed_previous_sitk = 'DVF_s'+str(previous_pyramid)+'_composed_sitk' if i_stage == 1: pyr[dvf_composed_previous_sitk] = pyr['DVF_s'+str(setting['ImagePyramidSchedule'][i_stage-1])+'_sitk'] elif i_stage > 1: pyr[dvf_composed_previous_sitk] = sitk.Add(pyr['DVF_s'+str(setting['ImagePyramidSchedule'][i_stage-2])+'_composed_up_sitk'], pyr['DVF_s'+str(setting['ImagePyramidSchedule'][i_stage - 1])+'_sitk']) pyr[dvf_composed_previous_up_sitk] = ip.upsampler_gpu(pyr[dvf_composed_previous_sitk], round(previous_pyramid/stage), dvf_output_size=pyr['fixed_im_s'+str(stage)+'_sitk'].GetSize()[::-1], ) if setting['WriteAfterEachStage']: sitk.WriteImage(sitk.Cast(pyr[dvf_composed_previous_up_sitk], sitk.sitkVectorFloat32), su.address_generator(setting, 'dvf_s_up', pair_info=pair_info, stage=previous_pyramid, stage_list=stage_list)) dvf_t = sitk.DisplacementFieldTransform(pyr[dvf_composed_previous_up_sitk]) # after this line DVF_composed_previous_up_sitk is converted to a transform. so we need to load it again. pyr['moved_im_s'+str(stage)+'_sitk'] = ip.resampler_by_dvf(pyr['moving_im_s' + str(stage)+'_sitk'], dvf_t, default_pixel_value=setting['data'][pair_info[1]['data']]['defaultPixelValue']) if setting['torsoMask']: pyr['moved_torso_s'+str(stage)+'_sitk'] = ip.resampler_by_dvf(pyr['moving_torso_s'+str(stage)+'_sitk'], dvf_t, default_pixel_value=0, interpolator=sitk.sitkNearestNeighbor) pyr[dvf_composed_previous_up_sitk] = dvf_t.GetDisplacementField() if setting['WriteAfterEachStage']: sitk.WriteImage(sitk.Cast(pyr['moved_im_s'+str(stage)+'_sitk'], setting['data'][pair_info[1]['data']]['imageByte']), su.address_generator(setting, 'moved_image', pair_info=pair_info, stage=stage, stage_list=stage_list)) input_regnet_moving = 'moved_im_s'+str(stage)+'_sitk' if setting['torsoMask']: input_regnet_moving_torso = 'moved_torso_s'+str(stage)+'_sitk' pyr['DVF_s'+str(stage)] = np.zeros(np.r_[pyr['fixed_im_s'+str(stage)+'_sitk'].GetSize()[::-1], 3], dtype=np.float64) pair_pyramid = real_pair.Images(setting, pair_info, stage=stage, fixed_im_sitk=pyr['fixed_im_s'+str(stage)+'_sitk'], moved_im_affine_sitk=pyr[input_regnet_moving], fixed_torso_sitk=pyr['fixed_torso_s'+str(stage)+'_sitk'], moved_torso_affine_sitk=pyr[input_regnet_moving_torso] ) # building and loading network tf.reset_default_graph() setting['R'] = network_dict['Stage'+str(stage)]['R'] # Radius of normal resolution patch size. Total size is (2*R +1) setting['Ry'] = network_dict['Stage'+str(stage)]['Ry'] # Radius of output. Total size is (2*Ry +1) setting['Ry_erode'] = network_dict['Stage'+str(stage)]['Ry_erode'] # at the test time, sometimes there are some problems at the border images_tf = tf.placeholder(tf.float32, shape=[None, 2 * setting['R'] + 1, 2 * setting['R'] + 1, 2 * setting['R'] + 1, 2], name="Images") bn_training = tf.placeholder(tf.bool, name='bn_training') x_fixed = images_tf[:, :, :, :, 0, np.newaxis] x_deformed = images_tf[:, :, :, :, 1, np.newaxis] dvf_tf = getattr(RegNetModel, network_dict['Stage'+str(stage)]['NetworkDesign'])(x_fixed, x_deformed, bn_training) extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) logging.debug(' Total number of variables %s' % (np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))) saver = tf.train.Saver(tf.global_variables(), max_to_keep=1) sess = tf.Session() sess.run(tf.global_variables_initializer()) saver.restore(sess, su.address_generator(setting, 'saved_model_with_step', current_experiment=network_dict['Stage'+str(stage)]['NetworkLoad'], step=network_dict['Stage'+str(stage)]['GlobalStepLoad'])) while not pair_pyramid.get_sweep_completed(): # The pyr[DVF_S] is the numpy DVF which will be filled. the dvf_np is an output # path from the network. We control the spatial location of both dvf in this function batch_im, win_center, win_r_before, win_r_after, predicted_begin, predicted_end = pair_pyramid.next_sweep_patch() time_before_gpu = time.time() [dvf_np] = sess.run([dvf_tf], feed_dict={images_tf: batch_im, bn_training: 0}) time_after_gpu = time.time() logging.debug('GPU: Data='+pair_info[0]['data']+' CN = {} center = {} is done in {:.2f}s '.format( pair_info[0]['cn'], win_center, time_after_gpu - time_before_gpu)) pyr['DVF_s'+str(stage)][win_center[0] - win_r_before[0]: win_center[0] + win_r_after[0], win_center[1] - win_r_before[1]: win_center[1] + win_r_after[1], win_center[2] - win_r_before[2]: win_center[2] + win_r_after[2], :] = \ dvf_np[0, predicted_begin[0]:predicted_end[0], predicted_begin[1]:predicted_end[1], predicted_begin[2]:predicted_end[2], :] pyr['DVF_s'+str(stage)+'_sitk'] = ip.array_to_sitk(pyr['DVF_s' + str(stage)], im_ref=pyr['fixed_im_s'+str(stage)+'_sitk'], is_vector=True) if i_stage == (len(setting['ImagePyramidSchedule'])-1): # when all stages are finished, final dvf and moved image are written dvf_composed_final_sitk = 'DVF_s'+str(stage)+'_composed_sitk' if len(setting['ImagePyramidSchedule']) == 1: pyr[dvf_composed_final_sitk] = pyr['DVF_s'+str(stage)+'_sitk'] else: pyr[dvf_composed_final_sitk] = sitk.Add(pyr['DVF_s'+str(setting['ImagePyramidSchedule'][-2])+'_composed_up_sitk'], pyr['DVF_s'+str(stage)+'_sitk']) sitk.WriteImage(sitk.Cast(pyr[dvf_composed_final_sitk], sitk.sitkVectorFloat32), su.address_generator(setting, 'dvf_s0', pair_info=pair_info, stage_list=stage_list)) dvf_t = sitk.DisplacementFieldTransform(pyr[dvf_composed_final_sitk]) pyr['moved_im_s0_sitk'] = ip.resampler_by_dvf(pyr['moving_im_s'+str(stage)+'_sitk'], dvf_t, default_pixel_value=setting['data'][pair_info[1]['data']]['defaultPixelValue']) sitk.WriteImage(sitk.Cast(pyr['moved_im_s0_sitk'], setting['data'][pair_info[1]['data']]['imageByte']), su.address_generator(setting, 'moved_image', pair_info=pair_info, stage=0, stage_list=stage_list)) time_after_dvf = time.time() logging.debug('Data='+pair_info[0]['data']+' CN = {} ImType = {} is done in {:.2f}s '.format( pair_info[0]['cn'], pair_info[0]['type_im'],time_after_dvf - time_before_dvf)) return 1