def __init__(self, dataset_name, base_folder, output_folder, model_file_name, num_embeddings, image_size, coord_factors, min_samples, sigma, border_size, parent_dilation, parent_frame_search):
    """Configure tracking hyper-parameters, a TF session and the validation dataset.

    Args:
        dataset_name: Selects the per-dataset ``additional_scale`` and
            ``normalization_consideration_factors`` below; a name missing from
            those tables raises ``KeyError`` (after the session was created).
        base_folder: Passed to ``Dataset`` as ``base_folder``.
        output_folder: Stored for later output writing.
        model_file_name: Stored as ``load_model_filename``.
        num_embeddings: Number of embedding channels of the network.
        image_size: Network input size; also used as output size.
        coord_factors, min_samples, sigma, border_size, parent_dilation,
        parent_frame_search: Stored unchanged for the tracking step.
    """
    # tracking / clustering hyper-parameters, stored as given
    self.coord_factors, self.min_samples, self.sigma = coord_factors, min_samples, sigma
    self.border_size, self.parent_dilation, self.parent_frame_search = border_size, parent_dilation, parent_frame_search

    # session that only grabs GPU memory as it is needed
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    self.sess = tf.Session(config=session_config)
    self.coord = tf.train.Coordinator()

    # output naming and debug switch
    self.image_prefix, self.track_file_name = 'mask', 'res_track.txt'
    self.save_debug_images = False

    # network geometry and data layout
    self.image_size = image_size
    self.data_format = 'channels_last'
    self.output_size = self.image_size
    self.num_embeddings = num_embeddings

    # paths
    self.base_folder, self.output_folder = base_folder, output_folder
    self.load_model_filename = model_file_name

    # per-dataset preprocessing settings
    scale_per_dataset = {
        'DIC-C2DH-HeLa': [1, 1],
        'Fluo-C2DL-MSC': [0.95, 0.95],
        'Fluo-N2DH-GOWT1': [1, 1],
        'Fluo-N2DH-SIM+': [1, 1],
        'Fluo-N2DL-HeLa': [1, 1],
        'PhC-C2DH-U373': [1, 1],
        'PhC-C2DL-PSC': [1, 1],
    }
    normalization_per_dataset = {
        'DIC-C2DH-HeLa': (0.2, 0.1),
        'Fluo-C2DL-MSC': (0.2, 0.01),
        'Fluo-N2DH-GOWT1': (0.2, 0.1),
        'Fluo-N2DH-SIM+': (0.2, 0.01),
        'Fluo-N2DL-HeLa': (0.2, 0.1),
        'PhC-C2DH-U373': (0.2, 0.1),
        'PhC-C2DL-PSC': (0.2, 0.1),
    }
    dataset_scale = scale_per_dataset[dataset_name]
    dataset_normalization = normalization_per_dataset[dataset_name]

    self.dataset = Dataset(self.image_size,
                           base_folder=self.base_folder,
                           data_format=self.data_format,
                           save_debug_images=self.save_debug_images,
                           normalization_consideration_factors=dataset_normalization,
                           additional_scale=dataset_scale,
                           image_gaussian_blur_sigma=2.0,
                           pad_image=False)
    self.dataset_val = self.dataset.dataset_val_single_frame()

    # axis layout depends on the data format
    is_channels_first = self.data_format == 'channels_first'
    self.channel_axis = 1 if is_channels_first else 3
    self.time_stack_axis = 1 if is_channels_first else 0
def __init__(self, dataset_name, input_image_folder, output_folder, model_file_name, num_embeddings, num_frames):
    """Configure the multi-frame tester: TF session, per-dataset settings and dataset.

    Args:
        dataset_name: Passed to the project's per-dataset parameter helpers
            (image sizes, dataset, instance-image-creator and tracker parameters).
        input_image_folder: Folder whose ``*.tif`` files are the video frames.
        output_folder: Target folder for outputs.
        model_file_name: Stored as ``load_model_filename``.
        num_embeddings: Number of embedding channels of the network.
        num_frames: Number of consecutive frames processed together.
    """
    self.dataset_name = dataset_name

    # share at most half of the GPU memory and allocate it lazily
    self.tf_config = tf.ConfigProto()
    gpu_options = self.tf_config.gpu_options
    gpu_options.per_process_gpu_memory_fraction = 0.5
    gpu_options.allow_growth = True
    self.sess = tf.Session(config=self.tf_config)
    self.coord = tf.train.Coordinator()

    # output naming and optional intermediate outputs
    self.image_prefix = 'mask'
    self.track_file_name = 'res_track.txt'
    self.save_all_embeddings, self.save_all_input_images, self.save_all_predictions = False, False, True

    # geometry: network size vs. full test image size, tiling stride and ignored border
    self.image_size, self.test_image_size = image_sizes_for_dataset_name(dataset_name)
    self.tiled_increment = [128, 128]
    self.instances_ignore_border = [32, 32]
    self.num_frames = num_frames
    self.data_format = 'channels_last'
    self.output_size = self.image_size
    self.num_embeddings = num_embeddings

    # paths
    self.input_image_folder = input_image_folder
    self.output_folder = output_folder
    self.load_model_filename = model_file_name

    # per-dataset parameter sets
    self.dataset_parameters = get_dataset_parameters(dataset_name)
    self.instance_image_creator_parameters = get_instance_image_creator_parameters(dataset_name)
    self.instance_tracker_parameters = get_instance_tracker_parameters(dataset_name)

    debug_image_folder = None
    if self.save_all_input_images:
        debug_image_folder = os.path.join(self.output_folder, 'input_images')
    self.dataset = Dataset(self.test_image_size,
                           base_folder=self.input_image_folder,
                           data_format=self.data_format,
                           debug_image_folder=debug_image_folder,
                           **self.dataset_parameters)
    self.dataset_val = self.dataset.dataset_val_single_frame()

    # frame ids: file stem without its first character, sorted
    frame_ids = []
    for frame_path in glob(self.input_image_folder + '/*.tif'):
        frame_ids.append(os.path.splitext(os.path.basename(frame_path))[0][1:])
    self.video_frames = sorted(frame_ids)

    is_channels_first = self.data_format == 'channels_first'
    self.channel_axis = 1 if is_channels_first else 3
    self.time_stack_axis = 1 if is_channels_first else 0
def __init__(self, dataset_name, input_image_folder, output_folder, model_file_name, num_embeddings, image_size, additional_scale, normalization_consideration_factors, coord_factors, min_samples, sigma, border_size, parent_dilation, parent_frame_search):
    """Configure tracking hyper-parameters, a TF session and the validation dataset.

    ``dataset_name`` is accepted for interface compatibility but not used here.
    ``additional_scale`` and ``normalization_consideration_factors`` are
    forwarded to ``Dataset``; the tracking hyper-parameters are stored as given.
    The ``*.tif`` files in ``input_image_folder`` define ``self.video_frames``.
    """
    # tracking / clustering hyper-parameters
    self.coord_factors, self.min_samples, self.sigma = coord_factors, min_samples, sigma
    self.border_size, self.parent_dilation, self.parent_frame_search = border_size, parent_dilation, parent_frame_search

    self.sess = tf.Session()
    self.coord = tf.train.Coordinator()

    # output naming and debug output switches
    self.image_prefix, self.track_file_name = 'mask', 'res_track.txt'
    self.save_all_embeddings = True
    self.save_all_input_images = True
    self.save_all_predictions = True
    self.save_debug_images = False

    # network geometry and layout
    self.image_size = image_size
    self.data_format = 'channels_last'
    self.output_size = self.image_size
    self.num_embeddings = num_embeddings

    # paths
    self.input_image_folder, self.output_folder = input_image_folder, output_folder
    self.load_model_filename = model_file_name

    self.dataset = Dataset(self.image_size,
                           base_folder=self.input_image_folder,
                           data_format=self.data_format,
                           save_debug_images=self.save_debug_images,
                           normalization_consideration_factors=normalization_consideration_factors,
                           additional_scale=additional_scale,
                           image_gaussian_blur_sigma=2.0,
                           pad_image=False)
    self.dataset_val = self.dataset.dataset_val_single_frame()

    # frame ids: file stem without its first character, sorted
    frame_paths = glob.glob(self.input_image_folder + '/*.tif')
    self.video_frames = sorted(os.path.splitext(os.path.basename(path))[0][1:] for path in frame_paths)

    is_channels_first = self.data_format == 'channels_first'
    self.channel_axis = 1 if is_channels_first else 3
    self.time_stack_axis = 1 if is_channels_first else 0
class MainLoop(object):
    """Sliding-window tester for the multi-frame (HourglassNet3D) embedding network.

    Windows of ``num_frames`` consecutive input frames are run through the network
    tile by tile. Embeddings of consecutive frame pairs are clustered into
    instances, merged across tiles and accumulated in an ``InstanceTracker``. The
    result is written as one ``mask<index>.tif`` per frame (index zero-padded to
    three digits) plus ``res_track.txt``.
    """

    def __init__(self, dataset_name, input_image_folder, output_folder, model_file_name, num_embeddings, num_frames):
        self.dataset_name = dataset_name
        self.sess = tf.Session()
        self.coord = tf.train.Coordinator()
        self.image_prefix = 'mask'
        self.track_file_name = 'res_track.txt'
        # switches for optional intermediate outputs
        self.save_all_embeddings = False
        self.save_all_input_images = False
        self.save_all_predictions = True
        # network input size and full test image size, both chosen per dataset
        self.image_size, self.test_image_size = image_sizes_for_dataset_name(dataset_name)
        # tiling stride and the border ignored when merging tiled instances
        self.tiled_increment = [128, 128]
        self.instances_ignore_border = [32, 32]
        self.num_frames = num_frames
        self.data_format = 'channels_first'
        self.output_size = self.image_size
        self.num_embeddings = num_embeddings
        self.input_image_folder = input_image_folder
        self.output_folder = output_folder
        self.load_model_filename = model_file_name
        self.dataset_parameters = get_dataset_parameters(dataset_name)
        self.instance_image_creator_parameters = get_instance_image_creator_parameters(dataset_name)
        self.instance_tracker_parameters = get_instance_tracker_parameters(dataset_name)
        self.dataset = Dataset(
            self.test_image_size,
            base_folder=self.input_image_folder,
            data_format=self.data_format,
            debug_image_folder=os.path.join(self.output_folder, 'input_images') if self.save_all_input_images else None,
            **self.dataset_parameters)
        self.dataset_val = self.dataset.dataset_val_single_frame()
        # frame ids are the file stems without their first character, sorted
        self.video_frames = glob(self.input_image_folder + '/*.tif')
        self.video_frames = sorted([
            os.path.splitext(os.path.basename(frame))[0][1:]
            for frame in self.video_frames
        ])
        if self.data_format == 'channels_first':
            self.channel_axis = 1
            self.time_stack_axis = 1
        else:
            self.channel_axis = 3
            self.time_stack_axis = 0

    def create_output_folder(self):
        create_directories(self.output_folder)

    def load_model(self):
        """Restore the model weights from ``self.load_model_filename``."""
        self.saver = tf.train.Saver()
        print('Restoring model ' + self.load_model_filename)
        self.saver.restore(self.sess, self.load_model_filename)

    def init_all(self):
        self.init_networks()
        self.load_model()
        self.create_output_folder()

    def run_test(self):
        self.init_all()
        print('Starting main test loop')
        self.test()

    def init_networks(self):
        """Build the validation graph: placeholders, network and l2-normalized embeddings."""
        network_image_size = self.image_size
        network_output_size = self.output_size

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1, self.num_frames] + network_image_size),
                ('instances_merged', [None] + network_output_size),
                ('instances_bac', [None] + network_output_size)
            ])
        else:
            # NOTE(review): this branch has no frame dimension; only channels_first is used by __init__
            data_generator_entries = OrderedDict([
                ('image', network_image_size + [1]),
                ('instances_merged', network_output_size + [None]),
                ('instances_bac', network_output_size + [None])
            ])

        # build val graph
        val_placeholders = create_placeholders(data_generator_entries, shape_prefix=[1])
        self.data_val = val_placeholders['image']

        with tf.variable_scope('net'):
            self.embeddings_0, self.embeddings_1 = network(
                self.data_val,
                num_outputs_embedding=self.num_embeddings,
                data_format=self.data_format,
                actual_network=HourglassNet3D,
                is_training=False)
            self.embeddings_normalized_0 = tf.nn.l2_normalize(self.embeddings_0, dim=self.channel_axis)
            self.embeddings_normalized_1 = tf.nn.l2_normalize(self.embeddings_1, dim=self.channel_axis)
            self.embeddings_cropped_val = (self.embeddings_normalized_0, self.embeddings_normalized_1)

    def test_cropped_image(self, dataset_entry, return_all_intermediate_embeddings=False):
        """Run the network tile by tile over ``dataset_entry['generators']['image']``.

        Returns:
            The list of stitched embedding images (one per network output). If
            ``return_all_intermediate_embeddings`` is set, additionally the list of
            per-tile outputs of the last network head, in tile order.
        """
        generators = dataset_entry['generators']
        full_image = generators['image']
        # initialize sizes based on data_format
        fetches = self.embeddings_cropped_val
        if self.data_format == 'channels_first':
            image_size_np = [1, self.num_frames] + list(reversed(self.image_size))
            full_image_size_np = list(full_image.shape)
            embeddings_size_np = [self.num_embeddings, self.num_frames] + list(reversed(self.image_size))
            full_embeddings_size_np = [self.num_embeddings] + list(full_image.shape[1:])
            inc = [0, 0] + list(reversed(self.tiled_increment))
        else:
            image_size_np = list(reversed(self.image_size)) + [1]
            full_image_size_np = list(full_image.shape)
            embeddings_size_np = list(reversed(self.image_size)) + [self.num_embeddings]
            full_embeddings_size_np = list(full_image.shape[0:2]) + [self.num_embeddings]
            inc = list(reversed(self.tiled_increment)) + [0]

        # one tiler for the input and one per network output, iterated in lockstep
        image_tiler = ImageTiler(full_image_size_np, image_size_np, inc, True, -1)
        embeddings_tilers = tuple([
            ImageTiler(full_embeddings_size_np, embeddings_size_np, inc, True, -1)
            for _ in range(len(self.embeddings_cropped_val))
        ])

        all_intermediate_embeddings = []
        for all_tilers in zip(*((image_tiler,) + embeddings_tilers)):
            image_tiler = all_tilers[0]
            embeddings_tilers = all_tilers[1:]
            current_image = image_tiler.get_current_data(full_image)
            feed_dict = {self.data_val: np.expand_dims(current_image, axis=0)}
            run_tuple = self.sess.run(fetches, feed_dict)
            image_tiler.set_current_data(current_image)
            for i, embeddings_tiler in enumerate(embeddings_tilers):
                embeddings = np.squeeze(run_tuple[i], axis=0)
                # keep the raw tile of the last head for pairwise clustering
                if return_all_intermediate_embeddings and i == len(embeddings_tilers) - 1:
                    all_intermediate_embeddings.append(embeddings)
                embeddings_tiler.set_current_data(embeddings)

        embeddings = [embeddings_tiler.output_image for embeddings_tiler in embeddings_tilers]

        if return_all_intermediate_embeddings:
            return embeddings, all_intermediate_embeddings
        else:
            return embeddings

    def merge_tiled_instances(self, tiled_instances):
        """Merge per-tile two-frame instance images into one full-size image."""
        # initialize sizes based on data_format
        instances_size_np = [2] + list(reversed(self.image_size))
        full_instances_size_np = [2] + list(reversed(self.test_image_size))
        inc = [0] + list(reversed(self.tiled_increment))
        instance_tiler = ImageTiler(full_instances_size_np, instances_size_np, inc, True, 0, output_image_dtype=np.uint16)
        instance_merger = InstanceMerger(ignore_border=self.instances_ignore_border)
        # read the result from the last tiler yielded by the iteration (as before),
        # without shadowing the tiler being iterated
        current_tiler = instance_tiler
        for i, current_tiler in enumerate(instance_tiler):
            current_instance_pair = tiled_instances[i]
            current_tiler.set_current_data(current_instance_pair, instance_merger.merge_as_larger_instances, merge_whole_image=True)
        return current_tiler.output_image

    def get_instances(self, stacked_two_embeddings):
        clusterer = InstanceImageCreator(**self.instance_image_creator_parameters)
        clusterer.create_instance_image(stacked_two_embeddings)
        return clusterer.label_image

    def get_merged_instances(self, stacked_two_embeddings_tile_list):
        """Cluster each tile's stacked two-frame embeddings and merge the tiles."""
        tiled_instances = []
        for stacked_two_embeddings in stacked_two_embeddings_tile_list:
            if self.data_format == 'channels_last':
                stacked_two_embeddings = np.transpose(stacked_two_embeddings, [3, 0, 1, 2])
            tiled_instances.append(self.get_instances(stacked_two_embeddings))
        return self.merge_tiled_instances(tiled_instances)

    def _stack_consecutive_embeddings(self, tile_embeddings_list, frame_index):
        """Stack frames ``frame_index`` and ``frame_index + 1`` of every tile.

        Each tile is indexed as ``tile[:, frame, :, :]`` (frame on axis 1, as in
        the channels_first layout of ``test_cropped_image``). The two frames are
        stacked along ``self.time_stack_axis``; one array is returned per tile.
        """
        return [
            np.stack([tile_embeddings[:, frame_index, :, :],
                      tile_embeddings[:, frame_index + 1, :, :]],
                     axis=self.time_stack_axis)
            for tile_embeddings in tile_embeddings_list
        ]

    def _save_tile_embeddings(self, stacked_two_embeddings_tile_list, frame_index):
        """Write every tile's stacked embeddings to ``embeddings/frame_###_tile_##.mha``."""
        for tile_i, tile_embeddings in enumerate(stacked_two_embeddings_tile_list):
            utils.io.image.write_np(
                tile_embeddings,
                os.path.join(self.output_folder, 'embeddings',
                             'frame_' + str(frame_index).zfill(3) + '_tile_' + str(tile_i).zfill(2) + '.mha'),
                compress=False)

    def test(self):
        """Track all frames in ``input_image_folder`` and write masks and tracks.

        Returns early (writing nothing) if there are no frames or fewer frames than
        ``num_frames``, since at least one full window is required.
        """
        if len(self.video_frames) == 0:
            print('No images found!')
            return
        if len(self.video_frames) < self.num_frames:
            print('Not enough images found! Need at least', self.num_frames, 'frames, found', len(self.video_frames))
            return
        print('Testing...', self.input_image_folder)
        video_frames_all = self.video_frames
        instance_tracker = InstanceTracker(**self.instance_tracker_parameters)
        for j in range(len(video_frames_all) - self.num_frames + 1):
            print('Processing frame', j)
            video_frames = video_frames_all[j:j + self.num_frames]
            images = []
            for video_frame in video_frames:
                dataset_entry = self.dataset_val.get({'image_id': video_frame})
                images.append(dataset_entry['generators']['image'])
            # stack the window's frames along axis 1 (the frame axis in channels_first)
            stacked_image = np.stack(images, axis=1)
            if self.save_all_input_images:
                utils.io.image.write_np(stacked_image, os.path.join(self.output_folder, 'input', 'frame_' + str(j) + '.mha'))
            full_dataset_entry = {'generators': {'image': stacked_image}}
            _, all_intermediate_embeddings = self.test_cropped_image(full_dataset_entry, return_all_intermediate_embeddings=True)
            if j == 0:
                # first window: every consecutive pair inside it is new
                pair_indices = range(self.num_frames - 1)
            else:
                # later windows overlap the previous one; only the last pair is new
                pair_indices = [self.num_frames - 2]
            for pair_index in pair_indices:
                stacked_two_embeddings_tile_list = self._stack_consecutive_embeddings(all_intermediate_embeddings, pair_index)
                instances = self.get_merged_instances(stacked_two_embeddings_tile_list)
                instance_tracker.add_new_label_image(instances)
                if self.save_all_embeddings:
                    # absolute index of the pair's first frame in the whole video
                    self._save_tile_embeddings(stacked_two_embeddings_tile_list, j + pair_index)

        if self.save_all_predictions:
            utils.io.image.write_np(instance_tracker.stacked_label_image.astype(np.uint16), os.path.join(self.output_folder, 'predictions', 'merged_instances.mha'))

        # resample to the input resolution using the last loaded frame's geometry
        to_size = dataset_entry['datasources']['image'].GetSize()
        transformation = dataset_entry['transformations']['image']
        instance_tracker.resample_stacked_label_image(to_size, transformation, 1.0)

        if self.save_all_predictions:
            utils.io.image.write_np(instance_tracker.stacked_label_image.astype(np.uint16), os.path.join(self.output_folder, 'predictions', 'merged_instances_resampled.mha'))

        instance_tracker.finalize()

        if self.save_all_predictions:
            utils.io.image.write_np(instance_tracker.stacked_label_image.astype(np.uint16), os.path.join(self.output_folder, 'predictions', 'merged_instances_final.mha'))

        track_tuples = instance_tracker.track_tuples
        final_track_image_np = instance_tracker.stacked_label_image

        print('Saving output images and tracks...')
        final_track_images_sitk = [
            utils.sitk_np.np_to_sitk(np.squeeze(im))
            for im in np.split(final_track_image_np, final_track_image_np.shape[0], axis=0)
        ]
        for i, final_track_image_sitk in enumerate(final_track_images_sitk):
            video_frame = str(i).zfill(3)
            utils.io.image.write(final_track_image_sitk, os.path.join(self.output_folder, self.image_prefix + video_frame + '.tif'))
        utils.io.text.save_list_csv(track_tuples, os.path.join(self.output_folder, self.track_file_name), delimiter=' ')
class MainLoop(object):
    """Frame-by-frame tester for the recurrent (LSTM-state) embedding network.

    Every input frame is fed once. The LSTM output states of one frame are fed
    back as the input states of the next. The normalized second embedding head
    is handed to an ``EmbeddingTracker`` slice by slice. The tracking result is
    written as one ``mask<frame id>.tif`` per input frame plus ``res_track.txt``.
    """

    def __init__(self, dataset_name, input_image_folder, output_folder, model_file_name, num_embeddings, image_size, additional_scale, normalization_consideration_factors, coord_factors, min_samples, sigma, border_size, parent_dilation, parent_frame_search):
        # tracking / clustering hyper-parameters
        self.coord_factors, self.min_samples, self.sigma = coord_factors, min_samples, sigma
        self.border_size, self.parent_dilation, self.parent_frame_search = border_size, parent_dilation, parent_frame_search

        self.sess = tf.Session()
        self.coord = tf.train.Coordinator()

        # output naming and optional outputs
        self.image_prefix, self.track_file_name = 'mask', 'res_track.txt'
        self.save_all_embeddings = True
        self.save_all_input_images = True
        self.save_all_predictions = True
        self.save_debug_images = False

        # network geometry and layout
        self.image_size = image_size
        self.data_format = 'channels_last'
        self.output_size = self.image_size
        self.num_embeddings = num_embeddings

        # paths
        self.input_image_folder, self.output_folder = input_image_folder, output_folder
        self.load_model_filename = model_file_name

        self.dataset = Dataset(self.image_size,
                               base_folder=self.input_image_folder,
                               data_format=self.data_format,
                               save_debug_images=self.save_debug_images,
                               normalization_consideration_factors=normalization_consideration_factors,
                               additional_scale=additional_scale,
                               image_gaussian_blur_sigma=2.0,
                               pad_image=False)
        self.dataset_val = self.dataset.dataset_val_single_frame()

        # frame ids: file stem without its first character, sorted
        frame_paths = glob.glob(self.input_image_folder + '/*.tif')
        self.video_frames = sorted(os.path.splitext(os.path.basename(path))[0][1:] for path in frame_paths)

        is_channels_first = self.data_format == 'channels_first'
        self.channel_axis = 1 if is_channels_first else 3
        self.time_stack_axis = 1 if is_channels_first else 0

    def create_output_folder(self):
        create_directories(self.output_folder)

    def load_model(self):
        """Restore the model weights from ``self.load_model_filename``."""
        saver = tf.train.Saver()
        self.saver = saver
        print('Restoring model ' + self.load_model_filename)
        saver.restore(self.sess, self.load_model_filename)

    def init_all(self):
        for setup_step in (self.init_networks, self.load_model, self.create_output_folder):
            setup_step()

    def run_test(self):
        self.init_all()
        print('Starting main test loop')
        self.test()

    def init_networks(self):
        """Build the validation graph with explicit LSTM state inputs and outputs."""
        if self.data_format == 'channels_first':
            image_shape = [1] + self.image_size
            label_shape = [None] + self.output_size
        else:
            image_shape = self.image_size + [1]
            label_shape = self.output_size + [None]
        data_generator_entries = OrderedDict([
            ('image', image_shape),
            ('instances_merged', label_shape),
            ('instances_bac', list(label_shape)),
        ])

        # build val graph
        placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(data_generator_entries, shape_prefix=[1])
        self.data_val = placeholders['image']
        self.tracking_val = placeholders['instances_merged']
        self.instances_bac_val = placeholders['instances_bac']

        with tf.variable_scope('net/rnn'):
            (self.embeddings_0,
             self.embeddings_1,
             self.lstm_input_states,
             self.lstm_output_states) = network_single_frame_with_lstm_states(
                self.data_val,
                num_outputs_embedding=self.num_embeddings,
                data_format=self.data_format)
            self.embeddings_normalized_0 = tf.nn.l2_normalize(self.embeddings_0, dim=self.channel_axis)
            self.embeddings_normalized_1 = tf.nn.l2_normalize(self.embeddings_1, dim=self.channel_axis)

    def test(self):
        """Track all frames in ``input_image_folder`` and write masks, tracks and extras."""
        if not self.video_frames:
            print('No images found!')
            return
        print('Testing...', self.input_image_folder)
        tracker = EmbeddingTracker(coord_factors=self.coord_factors,
                                   stack_neighboring_slices=2,
                                   min_cluster_size=self.min_samples,
                                   min_samples=self.min_samples,
                                   min_label_size_per_stack=self.min_samples / 2,
                                   save_label_stack=True,
                                   image_ignore_border=self.border_size,
                                   parent_search_dilation_size=self.parent_dilation,
                                   max_parent_search_frames=self.parent_frame_search)

        collected_images = []
        lstm_states = None  # None until the first frame has been processed
        for video_frame in self.video_frames:
            with Timer('processing video frame ' + str(video_frame)):
                dataset_entry = self.dataset_val.get({'image_id': video_frame})
                generators = dataset_entry['generators']
                feed_dict = {self.data_val: np.expand_dims(generators['image'], axis=0)}
                if lstm_states is not None:
                    # carry the recurrent state over from the previous frame
                    for state_index, state_placeholder in enumerate(self.lstm_input_states):
                        feed_dict[state_placeholder] = lstm_states[state_index]
                outputs = self.sess.run([self.embeddings_normalized_1] + list(self.lstm_output_states), feed_dict=feed_dict)
                frame_embeddings = np.squeeze(outputs[0], axis=0)
                lstm_states = outputs[1:]
                if self.data_format == 'channels_last':
                    # move the channel axis to the front before handing the slice to the tracker
                    frame_embeddings = np.transpose(frame_embeddings, [2, 0, 1])
                tracker.add_slice(frame_embeddings)
                if self.save_all_input_images:
                    collected_images.append(generators['image'])

        # finalize tracker and resample to the input resolution of the last frame
        transformation = dataset_entry['transformations']['image']
        input_image = dataset_entry['datasources']['image']
        tracker.finalize()
        tracker.resample_stacked_label_image(input_image, transformation, self.sigma)
        tracker.fix_tracks_after_resampling()
        track_tuples = tracker.track_tuples
        final_track_image_np = tracker.stacked_label_image

        print('Saving output images and tracks...')
        frame_label_images = np.split(final_track_image_np, final_track_image_np.shape[0], axis=0)
        for video_frame, frame_label_image in zip(self.video_frames, frame_label_images):
            output_path = os.path.join(self.output_folder, self.image_prefix + video_frame + '.tif')
            utils.io.image.write(utils.sitk_np.np_to_sitk(np.squeeze(frame_label_image)), output_path)
        utils.io.text.save_list_csv(track_tuples, os.path.join(self.output_folder, self.track_file_name), delimiter=' ')

        if self.save_all_embeddings:
            # embeddings are always 'channels_first'
            all_embeddings = np.stack(tracker.embeddings_slices, axis=1)
            utils.io.image.write_np(all_embeddings, os.path.join(self.output_folder, 'embeddings.mha'), 'channels_first')
        if self.save_all_input_images:
            all_images = np.stack(collected_images, axis=self.time_stack_axis)
            utils.io.image.write_np(all_images, os.path.join(self.output_folder, 'image.mha'), self.data_format)
        if self.save_all_predictions:
            all_predictions = np.stack(tracker.stacked_label_image, axis=self.time_stack_axis)
            utils.io.image.write_np(all_predictions, os.path.join(self.output_folder, 'predictions.mha'), self.data_format)
class MainLoop(object):
    """Frame-by-frame tester for the recurrent embedding network (embedding tracker).

    Frames are fed one at a time, with the LSTM states carried over. The second
    normalized embedding head is tracked by an ``EmbeddingTracker``. Results are
    resampled to the input geometry, optionally label-smoothed, and written as
    ``mask<frame id>.tif`` files plus a track file.
    """

    def __init__(self, dataset_name, base_folder, output_folder, model_file_name, num_embeddings, image_size, coord_factors, min_samples, sigma, border_size, parent_dilation, parent_frame_search):
        self.coord_factors = coord_factors
        self.min_samples = min_samples
        self.sigma = sigma
        self.border_size = border_size
        self.parent_dilation = parent_dilation
        self.parent_frame_search = parent_frame_search
        # allocate GPU memory on demand instead of all at once
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        self.coord = tf.train.Coordinator()
        self.image_prefix = 'mask'
        self.track_file_name = 'res_track.txt'
        self.save_debug_images = False
        self.image_size = image_size
        self.data_format = 'channels_last'
        self.output_size = self.image_size
        self.num_embeddings = num_embeddings
        self.base_folder = base_folder
        self.output_folder = output_folder
        self.load_model_filename = model_file_name
        # per-dataset preprocessing; unknown dataset names raise KeyError
        additional_scale = {
            'DIC-C2DH-HeLa': [1, 1],
            'Fluo-C2DL-MSC': [0.95, 0.95],
            'Fluo-N2DH-GOWT1': [1, 1],
            'Fluo-N2DH-SIM+': [1, 1],
            'Fluo-N2DL-HeLa': [1, 1],
            'PhC-C2DH-U373': [1, 1],
            'PhC-C2DL-PSC': [1, 1]
        }
        additional_scale = additional_scale[dataset_name]
        normalization_consideration_factors = {
            'DIC-C2DH-HeLa': (0.2, 0.1),
            'Fluo-C2DL-MSC': (0.2, 0.01),
            'Fluo-N2DH-GOWT1': (0.2, 0.1),
            'Fluo-N2DH-SIM+': (0.2, 0.01),
            'Fluo-N2DL-HeLa': (0.2, 0.1),
            'PhC-C2DH-U373': (0.2, 0.1),
            'PhC-C2DL-PSC': (0.2, 0.1)
        }
        normalization_consideration_factor = normalization_consideration_factors[dataset_name]
        self.dataset = Dataset(self.image_size,
                               base_folder=self.base_folder,
                               data_format=self.data_format,
                               save_debug_images=self.save_debug_images,
                               normalization_consideration_factors=normalization_consideration_factor,
                               additional_scale=additional_scale,
                               image_gaussian_blur_sigma=2.0,
                               pad_image=False)
        self.dataset_val = self.dataset.dataset_val_single_frame()
        if self.data_format == 'channels_first':
            self.channel_axis = 1
            self.time_stack_axis = 1
        else:
            self.channel_axis = 3
            self.time_stack_axis = 0

    def __del__(self):
        """Close the session and reset the default graph, if a session was created."""
        # __init__ may have failed before the session attribute was set
        sess = getattr(self, 'sess', None)
        if sess is None:
            return
        sess.close()
        tf.reset_default_graph()

    @staticmethod
    def _list_video_frames(folder):
        """Return the sorted frame ids of all ``*.tif`` files directly in ``folder``.

        A frame id is the file stem without its first character (e.g. a leading
        ``t``). ``os.path.join`` makes this work with or without a trailing path
        separator on ``folder``.
        """
        frame_paths = glob.glob(os.path.join(folder, '*.tif'))
        return sorted(os.path.splitext(os.path.basename(path))[0][1:] for path in frame_paths)

    def init_networks(self):
        """Build the validation graph with explicit LSTM state inputs and outputs."""
        network_image_size = self.image_size
        network_output_size = self.output_size

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1] + network_image_size),
                ('instances_merged', [None] + network_output_size),
                ('instances_bac', [None] + network_output_size)
            ])
        else:
            data_generator_entries = OrderedDict([
                ('image', network_image_size + [1]),
                ('instances_merged', network_output_size + [None]),
                ('instances_bac', network_output_size + [None])
            ])

        # build val graph
        val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(data_generator_entries, shape_prefix=[1])
        self.data_val = val_placeholders['image']
        self.tracking_val = val_placeholders['instances_merged']
        self.instances_bac_val = val_placeholders['instances_bac']

        with tf.variable_scope('net/rnn'):
            self.embeddings_val, self.embeddings_2_val, self.lstm_input_states_val, self.lstm_output_states_val = network_single_frame_with_lstm_states(
                self.data_val,
                num_outputs_embedding=self.num_embeddings,
                data_format=self.data_format)
            self.embeddings_normalized_val = tf.nn.l2_normalize(self.embeddings_val, dim=self.channel_axis)
            self.embeddings_normalized_2_val = tf.nn.l2_normalize(self.embeddings_2_val, dim=self.channel_axis)

    def test(self):
        """Run the network over all frames in ``base_folder``, track and write results.

        Besides the tracking output (see ``_finalize_and_write``), this writes the
        second embedding head (``embeddings_2.mha``), the network inputs
        (``image.mha``), the last frame's transformation (``transform.txt``) and,
        after every frame, the tracker's current label image (``merged.mha``).
        Returns early if no ``*.tif`` frames are found.
        """
        print('Testing...', self.base_folder)
        video_frames = self._list_video_frames(self.base_folder)
        if not video_frames:
            # everything below needs at least one processed frame
            print('No images found!')
            return

        tracker = EmbeddingTracker(coord_factors=self.coord_factors,
                                   stack_neighboring_slices=2,
                                   min_cluster_size=self.min_samples,
                                   min_samples=self.min_samples,
                                   min_label_size_per_stack=self.min_samples,
                                   save_label_stack=True,
                                   image_ignore_border=self.border_size,
                                   parent_search_dilation_size=self.parent_dilation,
                                   max_parent_search_frames=self.parent_frame_search)

        current_predictions_2 = []
        current_images = []
        current_lstm_states = None  # None until the first frame has been processed
        for video_frame in video_frames:
            with Timer('processing video frame ' + str(video_frame)):
                dataset_entry = self.dataset_val.get({'image_id': video_frame})
                generators = dataset_entry['generators']
                feed_dict = {self.data_val: np.expand_dims(generators['image'], axis=0)}
                if current_lstm_states is not None:
                    # continue from the previous frame's recurrent state
                    for state_index, state_placeholder in enumerate(self.lstm_input_states_val):
                        feed_dict[state_placeholder] = current_lstm_states[state_index]
                run_tuple = self.sess.run([self.embeddings_normalized_val, self.embeddings_normalized_2_val] + list(self.lstm_output_states_val), feed_dict=feed_dict)
                # run_tuple[0] (first embedding head) is fetched but not used here
                embeddings_softmax_2 = np.squeeze(run_tuple[1], axis=0)
                current_lstm_states = run_tuple[2:]
                current_images.append(generators['image'])
                current_predictions_2.append(embeddings_softmax_2)
                # move the channel axis to the front before handing the slice to the tracker
                tracker.add_slice(np.transpose(embeddings_softmax_2, [2, 0, 1]))
                if tracker.stacked_label_image is not None:
                    utils.io.image.write_np(tracker.stacked_label_image, os.path.join(self.output_folder, 'merged.mha'))

        # write intermediate outputs, freeing each large array right after writing it
        prediction_2 = np.stack(current_predictions_2, axis=self.time_stack_axis)
        del current_predictions_2
        utils.io.image.write_np(prediction_2, os.path.join(self.output_folder, 'embeddings_2.mha'), self.data_format)
        del prediction_2
        images = np.stack(current_images, axis=self.time_stack_axis)
        del current_images
        utils.io.image.write_np(images, os.path.join(self.output_folder, 'image.mha'), self.data_format)
        del images

        # geometry of the last processed frame is used for resampling
        input_image = dataset_entry['datasources']['image']
        transformation = dataset_entry['transformations']['image']
        sitk.WriteTransform(transformation, os.path.join(self.output_folder, 'transform.txt'))

        self._finalize_and_write(tracker, video_frames, input_image, transformation)

    def _finalize_and_write(self, tracker, video_frames, input_image, transformation):
        """Finalize ``tracker``, resample its labels to ``input_image`` and write results.

        With ``sigma == 1`` the resampling uses the 'label_gaussian' interpolator,
        otherwise 'nearest'; with ``sigma > 1`` each frame is additionally
        label-smoothed. Writes one mask per entry of ``video_frames``,
        ``label_stack.mha``, ``stacked.mha`` and the track file.
        """
        if self.sigma == 1:
            interpolator = 'label_gaussian'
        else:
            interpolator = 'nearest'

        tracker.finalize()
        track_tuples = tracker.track_tuples
        merged = tracker.stacked_label_image

        final_predictions = utils.sitk_image.transform_np_output_to_sitk_input(
            merged,
            output_spacing=None,
            channel_axis=0,
            input_image_sitk=input_image,
            transform=transformation,
            interpolator=interpolator,
            output_pixel_type=sitk.sitkUInt16)
        if self.sigma > 1:
            final_predictions = [
                utils.sitk_image.apply_np_image_function(im, lambda x: self.label_smooth(x, sigma=self.sigma))
                for im in final_predictions
            ]

        for video_frame, final_prediction in zip(video_frames, final_predictions):
            utils.io.image.write(final_prediction, os.path.join(self.output_folder, self.image_prefix + video_frame + '.tif'))
        utils.io.image.write_np(np.stack(tracker.label_stack_list, axis=1), os.path.join(self.output_folder, 'label_stack.mha'))

        final_predictions_stacked = utils.sitk_image.accumulate(final_predictions)
        utils.io.image.write(final_predictions_stacked, os.path.join(self.output_folder, 'stacked.mha'))
        print(track_tuples)
        utils.io.text.save_list_csv(track_tuples, os.path.join(self.output_folder, self.track_file_name), delimiter=' ')

    def label_smooth(self, im, sigma):
        """Smooth each label of the label image ``im`` separately and merge them again."""
        label_images, labels = utils.np_image.split_label_image_with_unknown_labels(im, dtype=np.float32)
        smoothed_label_images = utils.np_image.smooth_label_images(label_images, sigma=sigma, dtype=im.dtype)
        return utils.np_image.merge_label_images(smoothed_label_images, labels)

    def testx(self):
        """Re-run only the tracking/writing step from a saved ``label_stack.mha``."""
        label_stack = utils.sitk_np.sitk_to_np(utils.io.image.read(os.path.join(self.output_folder, 'label_stack.mha'), sitk_pixel_type=sitk.sitkVectorUInt16))
        label_stack = np.transpose(label_stack, [0, 3, 1, 2])
        tracker = EmbeddingTracker(coord_factors=self.coord_factors,
                                   stack_neighboring_slices=2,
                                   min_cluster_size=self.min_samples * 2,
                                   min_samples=self.min_samples,
                                   min_label_size_per_stack=self.min_samples,
                                   save_label_stack=True,
                                   image_ignore_border=self.border_size,
                                   parent_search_dilation_size=self.parent_dilation)
        tracker.set_label_stack(label_stack)
        video_frames = self._list_video_frames(self.base_folder)
        # the first frame provides the geometry for resampling
        dataset_entry = self.dataset_val.get({'image_id': video_frames[0]})
        input_image = dataset_entry['datasources']['image']
        transformation = dataset_entry['transformations']['image']
        self._finalize_and_write(tracker, video_frames, input_image, transformation)

    def create_output_folder(self):
        create_directories(self.output_folder)

    def load_model(self):
        """Restore the model weights from ``self.load_model_filename``."""
        # a single Saver; creating it twice added duplicate save/restore ops to the graph
        self.saver = tf.train.Saver()
        print('Restoring model ' + self.load_model_filename)
        self.saver.restore(self.sess, self.load_model_filename)

    def init_all(self):
        self.init_networks()
        self.load_model()
        self.create_output_folder()

    def run_test(self):
        self.init_all()
        print('Starting main test loop')
        self.test()
def __init__(self, dataset_name, input_image_folder, output_folder, model_file_name, num_embeddings, num_frames):
    """Set up the tester for one dataset and one input image folder.

    Creates the TensorFlow session and coordinator, looks up the image sizes
    and the dataset / instance image creator / instance tracker parameters for
    ``dataset_name``, builds the single-frame validation dataset, and collects
    the sorted frame ids of all ``*.tif`` files in ``input_image_folder``.

    :param dataset_name: Dataset identifier (e.g. 'DIC-C2DH-HeLa'); used as the
        key for all per-dataset parameter lookups.
    :param input_image_folder: Folder holding the input ``*.tif`` frames; also
        passed to the dataset as ``base_folder``.
    :param output_folder: Output folder; input debug images are written to its
        ``input_images`` subfolder when ``save_all_input_images`` is enabled.
    :param model_file_name: Stored as ``self.load_model_filename`` (presumably
        the checkpoint restored later - TODO confirm).
    :param num_embeddings: Stored as ``self.num_embeddings`` (presumably the
        number of embedding channels of the network - TODO confirm).
    :param num_frames: Stored as ``self.num_frames``.
    :raises ValueError: If ``self.hdbscan`` is enabled and no HDBSCAN settings
        are defined for ``dataset_name``.
    """
    self.dataset_name = dataset_name
    # NOTE(review): session uses the default config (no GPU memory options set); confirm intended.
    self.sess = tf.Session()
    self.coord = tf.train.Coordinator()
    self.image_prefix = 'mask'
    self.track_file_name = 'res_track.txt'
    # Debug / output switches.
    self.save_all_embeddings = False
    self.save_all_input_images = False
    self.save_all_predictions = True
    # When enabled, the instance image creator parameters are overridden for HDBSCAN clustering below.
    self.hdbscan = False
    self.image_size, self.test_image_size = image_sizes_for_dataset_name(dataset_name)
    self.tiled_increment = [128, 128]
    self.instances_ignore_border = [32, 32]
    self.num_frames = num_frames
    self.data_format = 'channels_last'
    self.output_size = self.image_size
    self.num_embeddings = num_embeddings
    self.input_image_folder = input_image_folder
    self.output_folder = output_folder
    self.load_model_filename = model_file_name
    self.dataset_parameters = get_dataset_parameters(dataset_name)
    self.instance_image_creator_parameters = get_instance_image_creator_parameters(dataset_name)
    if self.hdbscan:
        # Per-dataset (min_cluster_size, coord_factors) for HDBSCAN clustering.
        # A table lookup with an explicit error replaces an if-chain that left
        # both names unbound (UnboundLocalError) for any other dataset.
        hdbscan_settings = {
            'DIC-C2DH-HeLa': (1000, 0.02),
            'Fluo-N2DH-GOWT1': (1000, 0.001),
            'PhC-C2DH-U373': (500, 0.005),
            'Fluo-N2DL-HeLa': (25, 0.01),
        }
        if self.dataset_name not in hdbscan_settings:
            raise ValueError('No HDBSCAN settings defined for dataset {!r}'.format(self.dataset_name))
        min_cluster_size, coord_factors = hdbscan_settings[self.dataset_name]
        self.instance_image_creator_parameters['min_cluster_size'] = min_cluster_size
        self.instance_image_creator_parameters['coord_factors'] = coord_factors
        self.instance_image_creator_parameters['hdbscan'] = True
    self.instance_tracker_parameters = get_instance_tracker_parameters(dataset_name)
    debug_image_folder = os.path.join(self.output_folder, 'input_images') if self.save_all_input_images else None
    self.dataset = Dataset(
        self.test_image_size,
        base_folder=self.input_image_folder,
        data_format=self.data_format,
        debug_image_folder=debug_image_folder,
        **self.dataset_parameters)
    self.dataset_val = self.dataset.dataset_val_single_frame()
    # Frame ids are the file stems with their first character dropped
    # (presumably the 't' of names like 't000.tif' - TODO confirm naming scheme).
    frame_paths = glob(os.path.join(self.input_image_folder, '*.tif'))
    self.video_frames = sorted(os.path.splitext(os.path.basename(frame))[0][1:] for frame in frame_paths)
    if self.data_format == 'channels_first':
        self.channel_axis = 1
        self.time_stack_axis = 1
    else:
        self.channel_axis = 3
        self.time_stack_axis = 0