def iterator(self, id_list_filename, random):
    """
    Build the iterator used for iterating over the dataset.

    If neither self.generate_single_vertebrae nor self.generate_single_vertebrae_heatmap is set, the image ids of
    id_list_filename are iterated directly (key 'image_id'). Otherwise every image id is expanded into one
    (image_id, landmark_id) entry per valid landmark listed for that image in self.valid_landmarks_file.
    :param id_list_filename: The used id_list_filename of image_ids.
    :param random: Shuffle if true.
    :return: IdListIterator used for image_id (and landmark_id) iteration.
    """
    if not (self.generate_single_vertebrae or self.generate_single_vertebrae_heatmap):
        return IdListIterator(id_list_filename, random, keys=['image_id'], name='iterator')

    landmarks_per_image = utils.io.text.load_dict_csv(self.valid_landmarks_file)

    def whole_list_postprocessing(id_list):
        # Replace each row by one [image_id, landmark] row per valid landmark of that image.
        return [[row[0], landmark] for row in id_list for landmark in landmarks_per_image[row[0]]]

    return IdListIterator(id_list_filename,
                          random,
                          whole_list_postprocessing=whole_list_postprocessing,
                          keys=['image_id', 'landmark_id'],
                          name='iterator')
def iterator(self, id_list_filename, random):
    """
    Build the iterator used for iterating over the dataset.

    If neither self.generate_single_vertebrae nor self.generate_single_vertebrae_heatmap is set, the image ids of
    id_list_filename are iterated with use_shuffle=False (key 'image_id'). Otherwise every image id is expanded into
    one (image_id, landmark_id) entry per valid landmark listed for that image in self.valid_landmarks_file. For
    those entries:
    - a plain IdListIterator (use_shuffle=False) is used when random is False and self.resample_iterator is False;
    - a ResampleLabelsIdListIterator over the vertebra groups 'c', 't' and 'l' is used in all other cases.
    :param id_list_filename: The used id_list_filename of image_ids.
    :param random: Shuffle if true.
    :return: Iterator used for image_id (and landmark_id) iteration.
    """
    if not (self.generate_single_vertebrae or self.generate_single_vertebrae_heatmap):
        return IdListIterator(id_list_filename, random, keys=['image_id'], name='iterator', use_shuffle=False)

    landmarks_per_image = utils.io.text.load_dict_csv(self.valid_landmarks_file)

    def whole_list_postprocessing(id_list):
        # Replace each row by one [image_id, landmark] row per valid landmark of that image.
        return [[row[0], landmark] for row in id_list for landmark in landmarks_per_image[row[0]]]

    if not (random or self.resample_iterator):
        return IdListIterator(id_list_filename,
                              random,
                              whole_list_postprocessing=whole_list_postprocessing,
                              keys=['image_id', 'landmark_id'],
                              name='iterator',
                              use_shuffle=False)

    # Landmark index -> vertebra group:
    # 0-6: C1-C7 -> 'c', 7-18: T1-T12 and 25: T13 -> 't', 19-24: L1-L6 -> 'l', anything else -> 'u'.
    # NOTE(review): 'u' is not in the resampled label list below; how the iterator treats it is not visible here.
    def id_to_label_function(curr_id):
        vertebra_index = int(curr_id[1])
        if vertebra_index < 0:
            return 'u'
        if vertebra_index <= 6:
            return 'c'
        if vertebra_index <= 18 or vertebra_index == 25:
            return 't'
        if vertebra_index <= 24:
            return 'l'
        return 'u'

    return ResampleLabelsIdListIterator(id_list_filename,
                                        None,
                                        ['c', 't', 'l'],
                                        whole_list_postprocessing=whole_list_postprocessing,
                                        id_to_label_function=id_to_label_function,
                                        keys=['image_id', 'landmark_id'],
                                        name='iterator')
def dataset_val(self):
    """
    Build the validation dataset for videos. No random augmentation is performed.

    Entries are (video_id, frame_id) pairs from self.val_id_list_file_name, iterated with random=False. The spatial
    reference datasource is 'merged' when self.pad_image is set or self.crop_image_size is given, otherwise 'image'.
    :return: The validation dataset.
    """
    val_iterator = IdListIterator(self.val_id_list_file_name, random=False, keys=['video_id', 'frame_id'])
    datasources = self.datasources()
    data_generator_sources = self.data_generator_sources()
    data_generators = self.data_generators(None)
    reference_transformation = self.spatial_transformation()
    use_merged_reference = self.pad_image or self.crop_image_size is not None
    reference_key = 'merged' if use_merged_reference else 'image'

    debug_folder = None
    if self.save_debug_images:
        debug_folder = os.path.join(self.debug_folder_prefix, 'debug_val')

    def post_process_generators(generator_outputs):
        # Second argument False; the training builders pass True here.
        return self.all_generators_post_processing(generator_outputs, False)

    return ReferenceTransformationDataset(dim=self.dim,
                                          reference_datasource_keys={'image': reference_key},
                                          reference_transformation=reference_transformation,
                                          datasources=datasources,
                                          data_generators=data_generators,
                                          data_generator_sources=data_generator_sources,
                                          iterator=val_iterator,
                                          all_generators_post_processing=post_process_generators,
                                          debug_image_folder=debug_folder,
                                          use_only_first_reference_datasource_entry=True)
def dataset_train_single_frame(self):
    """
    Build the training dataset for single frames. Random augmentation is performed.

    Entries are (video_id, frame_id) pairs from self.train_id_list_file_name, iterated with random=True. The spatial
    reference datasource is 'merged' when self.pad_image is set or self.crop_image_size is given, otherwise 'image'.
    :return: The training dataset.
    """
    train_iterator = IdListIterator(self.train_id_list_file_name, random=True, keys=['video_id', 'frame_id'])
    datasources = self.datasources_single_frame()
    data_generator_sources = self.data_generator_sources()
    data_generators = self.data_generators_single_frame(self.postprocessing_random)
    augmented_transformation = self.spatial_transformation_augmented()
    use_merged_reference = self.pad_image or self.crop_image_size is not None
    reference_key = 'merged' if use_merged_reference else 'image'

    debug_folder = None
    if self.save_debug_images:
        debug_folder = os.path.join(self.debug_folder_prefix, 'debug_train')

    def post_process_generators(generator_outputs):
        # Second argument True for training; the validation builders pass False.
        return self.all_generators_post_processing(generator_outputs, True)

    return ReferenceTransformationDataset(dim=self.dim,
                                          reference_datasource_keys={'image': reference_key},
                                          reference_transformation=augmented_transformation,
                                          datasources=datasources,
                                          data_generators=data_generators,
                                          data_generator_sources=data_generator_sources,
                                          iterator=train_iterator,
                                          all_generators_post_processing=post_process_generators,
                                          debug_image_folder=debug_folder)
def dataset_val(self):
    """
    Build the validation dataset. No random augmentation is performed.

    Image ids come from self.val_id_list_file_name (random=False). If self.translate_by_random_factor is set, the
    last entry of the output image size is left as None (only the first two entries of self.image_size are kept).
    :return: The validation dataset.
    """
    id_iterator = IdListIterator(self.val_id_list_file_name, random=False, keys=['image_id'])
    sources = self.data_sources(id_iterator, False)
    output_size = (self.image_size[:2] + [None]) if self.translate_by_random_factor else self.image_size
    image_source = sources['image_datasource']
    transformation = self.spatial_transformation(image_source, output_size)
    generators = self.data_generators(image_source,
                                      sources['landmarks_datasource'],
                                      transformation,
                                      self.intensity_postprocessing,
                                      output_size)
    debug_folder = 'debug_val' if self.save_debug_images else None
    return GraphDataset(data_sources=list(sources.values()),
                        data_generators=list(generators.values()),
                        transformations=[transformation],
                        iterator=id_iterator,
                        debug_image_folder=debug_folder)
def dataset_train(self):
    """
    Build the training dataset. Random augmentation is performed.

    Image ids come from self.train_file (random=True); both 'image' and 'landmarks' are spatial reference datasources.
    :return: The training dataset.
    """
    train_iterator = IdListIterator(self.train_file, random=True, keys=['image_id'])
    datasources = self.datasources()
    data_generator_sources = self.data_generator_sources()
    data_generators = self.data_generators(self.postprocessing_random, self.split_labels)
    augmented_transformation = self.spatial_transformation_augmented()
    reference_keys = {'image': 'image', 'landmarks': 'landmarks'}
    debug_folder = 'debug_train' if self.save_debug_images else None
    return ReferenceTransformationDataset(dim=self.dim,
                                          reference_datasource_keys=reference_keys,
                                          reference_transformation=augmented_transformation,
                                          datasources=datasources,
                                          data_generators=data_generators,
                                          data_generator_sources=data_generator_sources,
                                          iterator=train_iterator,
                                          debug_image_folder=debug_folder)
def dataset_val(self):
    """
    Build the validation dataset. No random augmentation is performed.

    Image ids come from self.test_file (random=False). When self.cv == 0, the 'mask' entry is removed from the
    datasources, generator sources and generators before the dataset is built.
    :return: The validation dataset.
    """
    val_iterator = IdListIterator(self.test_file, random=False, keys=['image_id'])
    datasources = self.datasources()
    data_generator_sources = self.data_generator_sources()
    data_generators = self.data_generators(self.postprocessing, self.split_labels)
    transformation = self.spatial_transformation()
    if self.cv == 0:
        for mapping in (datasources, data_generator_sources, data_generators):
            del mapping['mask']
    reference_keys = {'image': 'image', 'landmarks': 'landmarks'}
    debug_folder = 'debug_val' if self.save_debug_images else None
    return ReferenceTransformationDataset(dim=self.dim,
                                          reference_datasource_keys=reference_keys,
                                          reference_transformation=transformation,
                                          datasources=datasources,
                                          data_generators=data_generators,
                                          data_generator_sources=data_generator_sources,
                                          iterator=val_iterator,
                                          debug_image_folder=debug_folder)
def dataset_train_single_frame(self):
    """
    Build the training dataset for single frames. Random augmentation is performed.

    Entries are (video_id, frame_id) pairs from self.train_id_list_file_name (random=True). The augmented spatial
    transformation is derived from the 'merged' datasource when self.pad_image is set or self.crop_image_size is
    given, otherwise from 'image'.
    :return: The training dataset.
    """
    frame_iterator = IdListIterator(self.train_id_list_file_name, random=True, keys=['video_id', 'frame_id'])
    sources = self.datasources_single_frame(frame_iterator)
    reference_key = 'merged' if (self.pad_image or self.crop_image_size is not None) else 'image'
    transformation = self.spatial_transformation_augmented(sources[reference_key])
    # The literal 2 is the first argument of data_generators_single_frame (presumably the image dimension - verify).
    generators = self.data_generators_single_frame(2, sources, transformation, self.postprocessing_random)
    post_processed = self.all_generators_post_processing(generators, False)
    debug_folder = 'debug_train' if self.save_debug_images else None
    return GraphDataset(data_generators=list(post_processed.values()),
                        data_sources=list(sources.values()),
                        transformations=[transformation],
                        iterator=frame_iterator,
                        debug_image_folder=debug_folder)
def dataset_val(self):
    """
    Build the validation dataset. No random augmentation is performed.

    Image ids come from self.val_id_list_file_name (random=False).
    :return: The validation dataset.
    """
    id_iterator = IdListIterator(self.val_id_list_file_name, random=False, keys=['image_id'])
    sources = self.data_sources(False, id_iterator)
    transformation = self.spatial_transformation(sources)
    generators = self.data_generators(sources, transformation, self.intensity_postprocessing)
    debug_folder = 'debug_val' if self.save_debug_images else None
    return GraphDataset(data_generators=list(generators.values()),
                        data_sources=list(sources.values()),
                        transformations=[transformation],
                        iterator=id_iterator,
                        debug_image_folder=debug_folder)
def dataset_train(self):
    """
    Build the training dataset for videos. Random augmentation is performed.

    Each (video_id, frame_id) entry of self.train_id_list_file_name (random=True) is expanded by a VideoFrameList
    into a list of frame id dicts. The VideoFrameList uses int(num_frames / 2) and int(num_frames / 2) - 1 as its
    two frame-count arguments, border_mode 'valid', a random start and self.random_skip_probability. The spatial
    reference datasource is 'merged' when self.pad_image is set or self.crop_image_size is given, otherwise 'image'.
    :return: The training dataset.
    """
    half_num_frames = int(self.num_frames / 2)
    frame_list = VideoFrameList(self.video_frame_list_file_name,
                                half_num_frames,
                                half_num_frames - 1,
                                border_mode='valid',
                                random_start=True,
                                random_skip_probability=self.random_skip_probability)

    def expand_to_frame_ids(entry):
        # Turn a single (video_id, frame_id) entry into the frame id dicts of its temporal window.
        return frame_list.get_id_dict_list(entry['video_id'], entry['frame_id'])

    train_iterator = IdListIterator(self.train_id_list_file_name,
                                    random=True,
                                    keys=['video_id', 'frame_id'],
                                    postprocessing=expand_to_frame_ids)
    datasources = self.datasources()
    data_generator_sources = self.data_generator_sources()
    data_generators = self.data_generators(self.postprocessing_random)
    augmented_transformation = self.spatial_transformation_augmented()
    use_merged_reference = self.pad_image or self.crop_image_size is not None
    reference_key = 'merged' if use_merged_reference else 'image'

    debug_folder = None
    if self.save_debug_images:
        debug_folder = os.path.join(self.debug_folder_prefix, 'debug_train')

    def post_process_generators(generator_outputs):
        # Second argument True for training; the validation builders pass False.
        return self.all_generators_post_processing(generator_outputs, True)

    return ReferenceTransformationDataset(dim=self.dim,
                                          reference_datasource_keys={'image': reference_key},
                                          reference_transformation=augmented_transformation,
                                          datasources=datasources,
                                          data_generators=data_generators,
                                          data_generator_sources=data_generator_sources,
                                          iterator=train_iterator,
                                          all_generators_post_processing=post_process_generators,
                                          debug_image_folder=debug_folder,
                                          use_only_first_reference_datasource_entry=True)
def dataset_val(self):
    """
    Build the validation dataset. No random augmentation is performed.

    Image ids come from self.val_id_list_file_name (random=False); 'image_datasource' is the spatial reference.
    :return: The validation dataset.
    """
    sources = self.data_sources(False)
    generator_sources = self.data_generator_sources()
    generators = self.data_generators(self.intensity_postprocessing)
    transformation = self.spatial_transformation()
    id_iterator = IdListIterator(self.val_id_list_file_name, random=False, keys=['image_id'])
    debug_folder = 'debug_val' if self.save_debug_images else None
    return ReferenceTransformationDataset(dim=self.dim,
                                          reference_datasource_keys={'image': 'image_datasource'},
                                          reference_transformation=transformation,
                                          datasources=sources,
                                          data_generators=generators,
                                          data_generator_sources=generator_sources,
                                          iterator=id_iterator,
                                          debug_image_folder=debug_folder)
def dataset_train(self):
    """
    Build the training dataset. Random augmentation is performed.

    Image ids come from self.train_file (random=True); the iterator node is named 'iterator'.
    :return: The training dataset.
    """
    id_iterator = IdListIterator(self.train_file, random=True, keys=['image_id'], name='iterator')
    sources = self.datasources(id_iterator, True)
    augmented_transformation = self.spatial_transformation_augmented(sources)
    generators = self.data_generators(sources, augmented_transformation, self.postprocessing_random)
    debug_folder = 'debug_train' if self.save_debug_images else None
    return GraphDataset(data_generators=list(generators.values()),
                        data_sources=list(sources.values()),
                        transformations=[augmented_transformation],
                        iterator=id_iterator,
                        debug_image_folder=debug_folder)
def dataset_train(self):
    """
    Build the training dataset. Random augmentation is performed.

    Image ids come from self.train_id_list_file_name (random=True). Generators use the augmented intensity
    postprocessing and self.image_size as output size.
    :return: The training dataset.
    """
    id_iterator = IdListIterator(self.train_id_list_file_name, random=True, keys=['image_id'])
    sources = self.data_sources(id_iterator, True)
    image_source = sources['image_datasource']
    augmented_transformation = self.spatial_transformation_augmented(image_source)
    generators = self.data_generators(image_source,
                                      sources['landmarks_datasource'],
                                      augmented_transformation,
                                      self.intensity_postprocessing_augmented,
                                      self.image_size)
    debug_folder = 'debug_train' if self.save_debug_images else None
    return GraphDataset(data_sources=list(sources.values()),
                        data_generators=list(generators.values()),
                        transformations=[augmented_transformation],
                        iterator=id_iterator,
                        debug_image_folder=debug_folder)
def dataset_val(self):
    """
    Returns the validation dataset for videos. No random augmentation is performed.

    Each (video_id, frame_id) entry of self.val_id_list_file_name (random=False) is expanded by a VideoFrameList
    into a list of frame id dicts, which feed the datasources. The VideoFrameList uses int(num_frames / 2) and
    int(num_frames / 2) - 1 as its two frame-count arguments, border_mode 'valid', no random start and no random
    frame skipping. The volumetric spatial transformation is derived from the 'merged' datasource when
    self.pad_image is set or self.crop_image_size is given, otherwise from 'image'.
    Debug images, if enabled, are written to 'debug_val'. The folder name was previously 'debug_train', which
    mixed validation output into the training debug folder; the other validation builders use 'debug_val'.
    :return: The validation dataset.
    """
    dim = 3
    full_video_frame_list_image = VideoFrameList(self.video_frame_list_file_name,
                                                 int(self.num_frames / 2),
                                                 int(self.num_frames / 2) - 1,
                                                 border_mode='valid',
                                                 random_start=False,
                                                 random_skip_probability=0.0)
    iterator = IdListIterator(self.val_id_list_file_name, random=False, keys=['video_id', 'frame_id'])
    # Graph node that expands the iterator's single (video_id, frame_id) entry into the frame id dicts of its window.
    iterator_postprocessing = LambdaNode(lambda x: full_video_frame_list_image.get_id_dict_list(x['video_id'], x['frame_id']),
                                         parents=[iterator])

    sources = self.datasources(iterator_postprocessing)
    image_key = 'merged' if self.pad_image or self.crop_image_size is not None else 'image'
    image_transformation = self.spatial_transformation_volumetric(sources[image_key])
    generators = self.data_generators(dim, sources, image_transformation, None)
    final_generators = self.all_generators_post_processing(generators, False)
    return GraphDataset(data_generators=list(final_generators.values()),
                        data_sources=list(sources.values()),
                        transformations=[image_transformation],
                        iterator=iterator,
                        debug_image_folder='debug_val' if self.save_debug_images else None)
def dataset_train(self):
    """
    Build the training dataset for videos. Random augmentation is performed.

    Each (video_id, frame_id) entry of self.train_id_list_file_name (random=True) is expanded by a VideoFrameList
    into a list of frame id dicts. The VideoFrameList uses num_frames - 1 and 0 as its two frame-count arguments,
    border_mode 'valid', a random start and self.random_skip_probability. The volumetric augmented transformation is
    derived from the 'merged' datasource when self.pad_image is set or self.crop_image_size is given, otherwise
    from 'image'.
    :return: The training dataset.
    """
    generator_dim = 3
    frame_list = VideoFrameList(self.video_frame_list_file_name,
                                self.num_frames - 1,
                                0,
                                border_mode='valid',
                                random_start=True,
                                random_skip_probability=self.random_skip_probability)

    def expand_to_frame_ids(entry):
        # Turn a single (video_id, frame_id) entry into the frame id dicts of its temporal window.
        return frame_list.get_id_dict_list(entry['video_id'], entry['frame_id'])

    train_iterator = IdListIterator(self.train_id_list_file_name,
                                    random=True,
                                    keys=['video_id', 'frame_id'],
                                    postprocessing=expand_to_frame_ids)
    sources = self.datasources(train_iterator)
    reference_key = 'merged' if (self.pad_image or self.crop_image_size is not None) else 'image'
    augmented_transformation = self.spatial_transformation_volumetric_augmented(sources[reference_key])
    generators = self.data_generators(generator_dim, sources, augmented_transformation, self.postprocessing_random)
    post_processed = self.all_generators_post_processing(generators, False)
    debug_folder = 'debug_train' if self.save_debug_images else None
    # Unlike the validation datasets, data_sources and transformations are not passed to GraphDataset here.
    return GraphDataset(data_generators=list(post_processed.values()),
                        iterator=train_iterator,
                        debug_image_folder=debug_folder)
def dataset_val(self):
    """
    Build the validation dataset. No random augmentation is performed.

    Image ids come from self.test_file (random=False); the iterator node is named 'iterator'. When self.cv == 0, the
    'landmarks' entry is removed from the sources and generators after the generators have been created.
    :return: The validation dataset.
    """
    id_iterator = IdListIterator(self.test_file, random=False, keys=['image_id'], name='iterator')
    sources = self.datasources(id_iterator, False)
    transformation = self.spatial_transformation(sources)
    generators = self.data_generators(sources, transformation, self.postprocessing)
    if self.cv == 0:
        for mapping in (sources, generators):
            del mapping['landmarks']
    debug_folder = 'debug_val' if self.save_debug_images else None
    return GraphDataset(data_generators=list(generators.values()),
                        data_sources=list(sources.values()),
                        transformations=[transformation],
                        iterator=id_iterator,
                        debug_image_folder=debug_folder)
class MainLoop(MainLoopBase):
    """
    Main loop for single-frame training of a network with two embedding outputs on a cell tracking challenge
    dataset, and frame-by-frame testing on whole validation videos.
    """

    def __init__(self, dataset_name):
        """
        Set hyperparameters, folders and datasets for the given dataset.

        :param dataset_name: Name of the cell tracking challenge dataset (e.g. 'DIC-C2DH-HeLa'). It is used as the key
                             into the per-dataset parameter tables below, so an unknown name raises KeyError.
        """
        super().__init__()
        self.dataset_name = dataset_name
        self.batch_size = 10
        self.learning_rate = 0.0001
        self.learning_rates = [self.learning_rate, self.learning_rate * 0.1]
        self.learning_rate_boundaries = [20000]
        self.max_iter = 40000
        self.test_iter = 5000
        self.disp_iter = 100
        self.snapshot_iter = self.test_iter
        self.test_initialization = True
        self.current_iter = 0
        self.reg_constant = 0.00001
        self.use_batch_norm = False
        self.invert_transformation = False
        self.use_pyro_dataset = False
        self.save_debug_images = False
        self.image_size = [256, 256]
        self.output_size = self.image_size
        self.data_format = 'channels_first'
        self.num_frames = 10
        self.embeddings_dim = 16
        self.test_on_challenge_data = True
        self.challenge_base_folder = '../celltrackingchallenge/'
        self.output_base_folder = '/media1/experiments/cell_tracking/miccai2018_segmentation/' + self.dataset_name
        self.training_base_folder = os.path.join(self.challenge_base_folder, 'trainingdataset/', self.dataset_name)
        self.testing_base_folder = os.path.join(self.challenge_base_folder, 'challengedataset/', self.dataset_name)
        self.output_folder = os.path.join(self.output_base_folder, self.output_folder_timestamp())
        self.embedding_factors = {'bac': 1, 'tra': 1}

        # Per-dataset parameters, selected by self.dataset_name.
        instance_image_radius_factors = {
            'DIC-C2DH-HeLa': 0.2,
            'Fluo-C2DL-MSC': 0.6,
            'Fluo-N2DH-GOWT1': 0.2,
            'Fluo-N2DH-SIM+': 0.2,
            'Fluo-N2DL-HeLa': 0.1,
            'PhC-C2DH-U373': 0.2,
            'PhC-C2DL-PSC': 0.1
        }
        instance_image_radius_factor = instance_image_radius_factors[self.dataset_name]
        label_gaussian_blur_sigmas = {
            'DIC-C2DH-HeLa': 2.0,
            'Fluo-C2DL-MSC': 0,
            'Fluo-N2DH-GOWT1': 0,
            'Fluo-N2DH-SIM+': 0,
            'Fluo-N2DL-HeLa': 0,
            'PhC-C2DH-U373': 0,
            'PhC-C2DL-PSC': 0
        }
        label_gaussian_blur_sigma = label_gaussian_blur_sigmas[self.dataset_name]
        crop_image_sizes = {
            'DIC-C2DH-HeLa': None,
            'Fluo-C2DL-MSC': [20, 20],
            'Fluo-N2DH-GOWT1': None,
            'Fluo-N2DH-SIM+': None,
            'Fluo-N2DL-HeLa': None,
            'PhC-C2DH-U373': None,
            'PhC-C2DL-PSC': None
        }
        crop_image_size = crop_image_sizes[self.dataset_name]
        normalization_consideration_factors = {
            'DIC-C2DH-HeLa': (0.2, 0.1),
            'Fluo-C2DL-MSC': (0.2, 0.01),
            'Fluo-N2DH-GOWT1': (0.2, 0.1),
            'Fluo-N2DH-SIM+': (0.2, 0.01),
            'Fluo-N2DL-HeLa': (0.2, 0.1),
            'PhC-C2DH-U373': (0.2, 0.1),
            'PhC-C2DL-PSC': (0.2, 0.1)
        }
        normalization_consideration_factor = normalization_consideration_factors[self.dataset_name]
        pad_images = {
            'DIC-C2DH-HeLa': True,
            'Fluo-C2DL-MSC': False,
            'Fluo-N2DH-GOWT1': True,
            'Fluo-N2DH-SIM+': True,
            'Fluo-N2DL-HeLa': True,
            'PhC-C2DH-U373': False,
            'PhC-C2DL-PSC': True
        }
        pad_image = pad_images[self.dataset_name]
        self.bitwise_instance_image = False

        # save_debug_images follows self.save_debug_images instead of a hard-coded True,
        # so the flag defined above actually controls whether debug images are written.
        self.dataset = Dataset(self.image_size,
                               num_frames=1,
                               base_folder=self.training_base_folder,
                               data_format=self.data_format,
                               save_debug_images=self.save_debug_images,
                               instance_image_radius_factor=instance_image_radius_factor,
                               max_num_instances=32,
                               train_id_file='tra_all.csv',
                               val_id_file='tra_all.csv',
                               image_gaussian_blur_sigma=1.0,
                               label_gaussian_blur_sigma=label_gaussian_blur_sigma,
                               normalization_consideration_factors=normalization_consideration_factor,
                               pad_image=pad_image,
                               crop_image_size=crop_image_size)
        self.dataset_train = self.dataset.dataset_train_single_frame()
        self.dataset_val = self.dataset.dataset_val_single_frame()
        # Fetch one training entry up front (presumably as an early sanity check of the data pipeline).
        self.dataset_train.get_next()

        self.setup_base_folder = os.path.join(self.training_base_folder, 'setup')
        self.video_frame_list_file_name = os.path.join(self.setup_base_folder, 'frames.csv')
        # Iterates whole videos (key 'video_id' only) without shuffling; consumed by test().
        self.iterator_val = IdListIterator(os.path.join(self.setup_base_folder, 'video_only_all.csv'),
                                           random=False,
                                           keys=['video_id'])

    def initNetworks(self):
        """
        Build the training graph (fed by a padding data queue) and the validation graph (fed by placeholders).

        Both graphs share weights through tf.make_template. The training loss sums the 'bac' and 'tra' embedding
        losses of both network outputs plus an optional weight regularization term.
        """
        network_image_size = self.image_size

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([('image', [1] + network_image_size),
                                                  ('instances_merged', [None] + network_image_size),
                                                  ('instances_bac', [None] + network_image_size)])
        else:
            data_generator_entries = OrderedDict([('image', network_image_size + [1]),
                                                  ('instances_merged', network_image_size + [None]),
                                                  ('instances_bac', network_image_size + [None])])

        # create model with shared weights between train and val
        training_net = tf.make_template('net', network)

        def loss_function(prediction, groundtruth):
            return cosine_embedding_per_instance_loss(prediction,
                                                      groundtruth,
                                                      data_format=self.data_format,
                                                      normalize=True,
                                                      term_1_squared=True,
                                                      l=1.0)

        # build train graph
        self.train_queue = DataGeneratorPadding(self.dataset_train,
                                                self.coord,
                                                data_generator_entries,
                                                batch_size=self.batch_size)
        data, tracking, instances_bac = self.train_queue.dequeue()
        embeddings_0, embeddings_1 = training_net(data,
                                                  num_outputs_embedding=self.embeddings_dim,
                                                  is_training=True,
                                                  data_format=self.data_format)

        # losses
        background_embedding_loss_0 = self.embedding_factors['bac'] * loss_function(embeddings_0, instances_bac)
        tra_embedding_loss_0 = self.embedding_factors['tra'] * loss_function(embeddings_0, tracking)
        background_embedding_loss_1 = self.embedding_factors['bac'] * loss_function(embeddings_1, instances_bac)
        tra_embedding_loss_1 = self.embedding_factors['tra'] * loss_function(embeddings_1, tracking)
        self.loss_net = background_embedding_loss_0 + tra_embedding_loss_0 + background_embedding_loss_1 + tra_embedding_loss_1

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            if self.reg_constant > 0:
                reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
                self.loss_reg = self.reg_constant * tf.add_n(reg_losses)
                self.loss = self.loss_net + self.loss_reg
            else:
                self.loss_reg = 0
                self.loss = self.loss_net

        self.train_losses = OrderedDict([('loss_bac_emb_0', background_embedding_loss_0),
                                         ('loss_tra_emb_0', tra_embedding_loss_0),
                                         ('loss_bac_emb_1', background_embedding_loss_1),
                                         ('loss_tra_emb_1', tra_embedding_loss_1),
                                         ('loss_reg', self.loss_reg)])

        # solver
        global_step = tf.Variable(self.current_iter)
        learning_rate = tf.train.piecewise_constant(global_step, self.learning_rate_boundaries, self.learning_rates)
        self.optimizer = tf.contrib.opt.NadamOptimizer(learning_rate=learning_rate).minimize(self.loss, global_step=global_step)

        # build val graph
        val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(data_generator_entries, shape_prefix=[1])
        self.data_val = val_placeholders['image']
        self.tracking_val = val_placeholders['instances_merged']
        self.instances_bac_val = val_placeholders['instances_bac']
        self.embeddings_0_val, self.embeddings_1_val = training_net(self.data_val,
                                                                    num_outputs_embedding=self.embeddings_dim,
                                                                    is_training=False,
                                                                    data_format=self.data_format)
        # NOTE(review): dim=1 is the channel axis only for channels_first outputs - confirm for channels_last.
        self.embeddings_normalized_1_val = tf.nn.l2_normalize(self.embeddings_1_val, dim=1)

        # losses
        self.background_embedding_loss_0_val = self.embedding_factors['bac'] * loss_function(self.embeddings_0_val, self.instances_bac_val)
        self.tra_embedding_loss_0_val = self.embedding_factors['tra'] * loss_function(self.embeddings_0_val, self.tracking_val)
        self.background_embedding_loss_1_val = self.embedding_factors['bac'] * loss_function(self.embeddings_1_val, self.instances_bac_val)
        self.tra_embedding_loss_1_val = self.embedding_factors['tra'] * loss_function(self.embeddings_1_val, self.tracking_val)
        # Only the losses of the first output form loss_val; the second-output losses are still reported in val_losses.
        self.loss_val = self.background_embedding_loss_0_val + self.tra_embedding_loss_0_val
        self.val_losses = OrderedDict([('loss_bac_emb_0', self.background_embedding_loss_0_val),
                                       ('loss_tra_emb_0', self.tra_embedding_loss_0_val),
                                       ('loss_bac_emb_1', self.background_embedding_loss_1_val),
                                       ('loss_tra_emb_1', self.tra_embedding_loss_1_val),
                                       ('loss_reg', self.loss_reg)])

    def test(self):
        """
        Run the validation graph frame by frame over every video of self.iterator_val.

        For each video, the l2-normalized second-output embeddings of all frames are stacked along axis 1 and written
        to <output_folder>/out/iter_<current_iter>/<video_id>_embeddings_softmax.mha. Validation losses are
        accumulated via self.val_loss_aggregator and finalized after all videos.
        """
        print('Testing...')
        channel_axis = 0
        if self.data_format == 'channels_last':
            # NOTE(review): embeddings are squeezed to (spatial..., channels); for 2D images that would be axis 2 - confirm 3.
            channel_axis = 3
        interpolator = 'cubic'
        video_id_frames = utils.io.text.load_dict_csv(self.video_frame_list_file_name)
        num_entries = self.iterator_val.num_entries()
        for current_entry_index in range(num_entries):
            video_id = self.iterator_val.get_next_id()['video_id']
            video_frames = video_id_frames[video_id]
            current_embeddings_softmax = []
            for video_frame in video_frames:
                current_id = video_id + '_' + video_frame
                dataset_entry = self.dataset_val.get({'video_id': video_id,
                                                      'frame_id': video_frame,
                                                      'unique_id': current_id})
                datasources = dataset_entry['datasources']
                generators = dataset_entry['generators']
                transformations = dataset_entry['transformations']
                feed_dict = {self.data_val: np.expand_dims(generators['image'], axis=0),
                             self.tracking_val: np.expand_dims(generators['instances_merged'], axis=0),
                             self.instances_bac_val: np.expand_dims(generators['instances_bac'], axis=0)}
                run_tuple = self.sess.run((self.embeddings_1_val, self.embeddings_normalized_1_val, self.loss_val) + self.val_loss_aggregator.get_update_ops(),
                                          feed_dict=feed_dict)
                embeddings = np.squeeze(run_tuple[0], axis=0)
                embeddings_softmax = np.squeeze(run_tuple[1], axis=0)
                current_embeddings_softmax.append(embeddings_softmax)
                if self.invert_transformation:
                    input_sitk = datasources['tra']
                    transformation = transformations['image']
                    # NOTE(review): the result is discarded, so this branch currently has no visible effect.
                    _ = utils.sitk_image.transform_np_output_to_sitk_input(output_image=embeddings,
                                                                           output_spacing=None,
                                                                           channel_axis=channel_axis,
                                                                           input_image_sitk=input_sitk,
                                                                           transform=transformation,
                                                                           interpolator=interpolator)
            current_embeddings_softmax = np.stack(current_embeddings_softmax, axis=1)
            utils.io.image.write_np(current_embeddings_softmax,
                                    os.path.join(self.output_folder, 'out/iter_' + str(self.current_iter) + '/' + video_id + '_embeddings_softmax.mha'))
            tensorflow_train.utils.tensorflow_util.print_progress_bar(current_entry_index, num_entries, prefix='Testing ', suffix=' complete')

        # finalize loss values
        self.val_loss_aggregator.finalize(self.current_iter)
def __init__(self, dataset_name):
    """
    Set hyperparameters, folders and datasets for the given dataset.

    :param dataset_name: Name of the cell tracking challenge dataset (e.g. 'DIC-C2DH-HeLa'). It is used as the key into
                         the per-dataset parameter tables below, so an unknown name raises KeyError.
    """
    super().__init__()
    self.dataset_name = dataset_name
    self.batch_size = 10
    self.learning_rate = 0.0001
    self.learning_rates = [self.learning_rate, self.learning_rate * 0.1]
    self.learning_rate_boundaries = [20000]
    self.max_iter = 40000
    self.test_iter = 5000
    self.disp_iter = 100
    self.snapshot_iter = self.test_iter
    self.test_initialization = True
    self.current_iter = 0
    self.reg_constant = 0.00001
    self.use_batch_norm = False
    self.invert_transformation = False
    self.use_pyro_dataset = False
    self.save_debug_images = False
    self.image_size = [256, 256]
    self.output_size = self.image_size
    self.data_format = 'channels_first'
    self.num_frames = 10
    self.embeddings_dim = 16
    self.test_on_challenge_data = True
    self.challenge_base_folder = '../celltrackingchallenge/'
    self.output_base_folder = '/media1/experiments/cell_tracking/miccai2018_segmentation/' + self.dataset_name
    self.training_base_folder = os.path.join(self.challenge_base_folder, 'trainingdataset/', self.dataset_name)
    self.testing_base_folder = os.path.join(self.challenge_base_folder, 'challengedataset/', self.dataset_name)
    self.output_folder = os.path.join(self.output_base_folder, self.output_folder_timestamp())
    self.embedding_factors = {'bac': 1, 'tra': 1}

    # Per-dataset parameters, selected by self.dataset_name.
    instance_image_radius_factors = {
        'DIC-C2DH-HeLa': 0.2,
        'Fluo-C2DL-MSC': 0.6,
        'Fluo-N2DH-GOWT1': 0.2,
        'Fluo-N2DH-SIM+': 0.2,
        'Fluo-N2DL-HeLa': 0.1,
        'PhC-C2DH-U373': 0.2,
        'PhC-C2DL-PSC': 0.1
    }
    instance_image_radius_factor = instance_image_radius_factors[self.dataset_name]
    label_gaussian_blur_sigmas = {
        'DIC-C2DH-HeLa': 2.0,
        'Fluo-C2DL-MSC': 0,
        'Fluo-N2DH-GOWT1': 0,
        'Fluo-N2DH-SIM+': 0,
        'Fluo-N2DL-HeLa': 0,
        'PhC-C2DH-U373': 0,
        'PhC-C2DL-PSC': 0
    }
    label_gaussian_blur_sigma = label_gaussian_blur_sigmas[self.dataset_name]
    crop_image_sizes = {
        'DIC-C2DH-HeLa': None,
        'Fluo-C2DL-MSC': [20, 20],
        'Fluo-N2DH-GOWT1': None,
        'Fluo-N2DH-SIM+': None,
        'Fluo-N2DL-HeLa': None,
        'PhC-C2DH-U373': None,
        'PhC-C2DL-PSC': None
    }
    crop_image_size = crop_image_sizes[self.dataset_name]
    normalization_consideration_factors = {
        'DIC-C2DH-HeLa': (0.2, 0.1),
        'Fluo-C2DL-MSC': (0.2, 0.01),
        'Fluo-N2DH-GOWT1': (0.2, 0.1),
        'Fluo-N2DH-SIM+': (0.2, 0.01),
        'Fluo-N2DL-HeLa': (0.2, 0.1),
        'PhC-C2DH-U373': (0.2, 0.1),
        'PhC-C2DL-PSC': (0.2, 0.1)
    }
    normalization_consideration_factor = normalization_consideration_factors[self.dataset_name]
    pad_images = {
        'DIC-C2DH-HeLa': True,
        'Fluo-C2DL-MSC': False,
        'Fluo-N2DH-GOWT1': True,
        'Fluo-N2DH-SIM+': True,
        'Fluo-N2DL-HeLa': True,
        'PhC-C2DH-U373': False,
        'PhC-C2DL-PSC': True
    }
    pad_image = pad_images[self.dataset_name]
    self.bitwise_instance_image = False

    # save_debug_images follows self.save_debug_images instead of a hard-coded True,
    # so the flag defined above actually controls whether debug images are written.
    self.dataset = Dataset(self.image_size,
                           num_frames=1,
                           base_folder=self.training_base_folder,
                           data_format=self.data_format,
                           save_debug_images=self.save_debug_images,
                           instance_image_radius_factor=instance_image_radius_factor,
                           max_num_instances=32,
                           train_id_file='tra_all.csv',
                           val_id_file='tra_all.csv',
                           image_gaussian_blur_sigma=1.0,
                           label_gaussian_blur_sigma=label_gaussian_blur_sigma,
                           normalization_consideration_factors=normalization_consideration_factor,
                           pad_image=pad_image,
                           crop_image_size=crop_image_size)
    self.dataset_train = self.dataset.dataset_train_single_frame()
    self.dataset_val = self.dataset.dataset_val_single_frame()
    # Fetch one training entry up front (presumably as an early sanity check of the data pipeline).
    self.dataset_train.get_next()

    self.setup_base_folder = os.path.join(self.training_base_folder, 'setup')
    self.video_frame_list_file_name = os.path.join(self.setup_base_folder, 'frames.csv')
    # Iterates whole videos (key 'video_id' only) without shuffling.
    self.iterator_val = IdListIterator(os.path.join(self.setup_base_folder, 'video_only_all.csv'),
                                       random=False,
                                       keys=['video_id'])
class MainLoop(MainLoopBase):
    """Train and test the recurrent embedding network on one Cell Tracking Challenge dataset.

    Training feeds clips of ``num_frames`` consecutive frames. Testing runs the
    single-frame network frame by frame, feeding the LSTM states produced for the
    previous frame back in, and writes the L2-normalized embeddings of every
    video to ``<output_folder>/out_first/iter_<current_iter>/``.
    """

    def __init__(self, dataset_name):
        """Set hyperparameters, paths and the training/validation datasets.

        :param dataset_name: Name of the challenge dataset (e.g. 'DIC-C2DH-HeLa'). It must be a key of
            the per-dataset parameter tables below, otherwise a KeyError is raised.
        """
        super().__init__()
        self.dataset_name = dataset_name
        self.batch_size = 1
        self.learning_rate = 0.0001
        # piecewise-constant schedule (see initNetworks): divided by 10 after 20000 iterations
        self.learning_rates = [self.learning_rate, self.learning_rate * 0.1]
        self.learning_rate_boundaries = [20000]
        self.max_iter = 40000
        self.test_iter = 5000
        self.disp_iter = 100
        self.snapshot_iter = self.test_iter
        self.test_initialization = True
        self.current_iter = 0
        self.reg_constant = 0.00001
        self.use_batch_norm = False
        self.invert_transformation = False
        self.use_pyro_dataset = False
        self.save_debug_images = False
        self.image_size = [256, 256]
        self.output_size = self.image_size
        self.data_format = 'channels_first'
        self.num_frames = 10
        self.embeddings_dim = 16
        self.test_on_challenge_data = True
        self.challenge_base_folder = '../celltrackingchallenge/'
        self.output_base_folder = '/media1/experiments/cell_tracking/miccai2018/' + self.dataset_name
        self.training_base_folder = os.path.join(self.challenge_base_folder, 'trainingdataset/', self.dataset_name)
        self.testing_base_folder = os.path.join(self.challenge_base_folder, 'challengedataset/', self.dataset_name)
        self.output_folder = os.path.join(self.output_base_folder, self.output_folder_timestamp())
        self.embedding_factors = {'bac': 1, 'tra': 1}
        # When testing on the challenge data, all annotated training videos are used for training.
        if self.test_on_challenge_data:
            self.train_id_file = 'tra_all.csv'
            self.val_id_file = 'tra_all.csv'
        else:
            self.train_id_file = 'tra_train.csv'
            self.val_id_file = 'tra_val.csv'

        # Per-dataset preprocessing parameters; indexing with an unknown dataset_name raises KeyError.
        instance_image_radius_factors = {
            'DIC-C2DH-HeLa': 0.2,
            'Fluo-C2DL-MSC': 0.6,
            'Fluo-N2DH-GOWT1': 0.2,
            'Fluo-N2DH-SIM+': 0.2,
            'Fluo-N2DL-HeLa': 0.1,
            'PhC-C2DH-U373': 0.2,
            'PhC-C2DL-PSC': 0.1
        }
        instance_image_radius_factor = instance_image_radius_factors[self.dataset_name]
        label_gaussian_blur_sigmas = {
            'DIC-C2DH-HeLa': 2.0,
            'Fluo-C2DL-MSC': 0,
            'Fluo-N2DH-GOWT1': 0,
            'Fluo-N2DH-SIM+': 0,
            'Fluo-N2DL-HeLa': 0,
            'PhC-C2DH-U373': 0,
            'PhC-C2DL-PSC': 0
        }
        label_gaussian_blur_sigma = label_gaussian_blur_sigmas[self.dataset_name]
        crop_image_sizes = {
            'DIC-C2DH-HeLa': None,
            'Fluo-C2DL-MSC': [20, 20],
            'Fluo-N2DH-GOWT1': None,
            'Fluo-N2DH-SIM+': None,
            'Fluo-N2DL-HeLa': None,
            'PhC-C2DH-U373': None,
            'PhC-C2DL-PSC': None
        }
        crop_image_size = crop_image_sizes[self.dataset_name]
        normalization_consideration_factors = {
            'DIC-C2DH-HeLa': (0.2, 0.1),
            'Fluo-C2DL-MSC': (0.2, 0.01),
            'Fluo-N2DH-GOWT1': (0.2, 0.1),
            'Fluo-N2DH-SIM+': (0.2, 0.01),
            'Fluo-N2DL-HeLa': (0.2, 0.1),
            'PhC-C2DH-U373': (0.2, 0.1),
            'PhC-C2DL-PSC': (0.2, 0.1)
        }
        normalization_consideration_factor = normalization_consideration_factors[self.dataset_name]
        pad_images = {
            'DIC-C2DH-HeLa': True,
            'Fluo-C2DL-MSC': False,
            'Fluo-N2DH-GOWT1': True,
            'Fluo-N2DH-SIM+': True,
            'Fluo-N2DL-HeLa': True,
            'PhC-C2DH-U373': False,
            'PhC-C2DL-PSC': True
        }
        pad_image = pad_images[self.dataset_name]

        self.dataset = Dataset(
            self.image_size,
            self.num_frames,
            base_folder=self.training_base_folder,
            data_format=self.data_format,
            save_debug_images=self.save_debug_images,
            instance_image_radius_factor=instance_image_radius_factor,
            max_num_instances=16,
            train_id_file=self.train_id_file,
            val_id_file=self.val_id_file,
            image_gaussian_blur_sigma=2.0,
            label_gaussian_blur_sigma=label_gaussian_blur_sigma,
            normalization_consideration_factors=normalization_consideration_factor,
            pad_image=pad_image,
            crop_image_size=crop_image_size)
        self.dataset_train = self.dataset.dataset_train()
        # draw one training sample up front (presumably a sanity check that the pipeline loads data — confirm)
        self.dataset_train.get_next()

        if self.test_on_challenge_data:
            # Label loading and instance creation are disabled here (presumably because the
            # challenge test videos ship without annotations).
            # NOTE(review): unlike self.dataset, normalization_consideration_factors is not passed,
            # so Dataset's default normalization is used at test time — confirm this is intended.
            val_dataset_source = Dataset(
                self.image_size,
                self.num_frames,
                base_folder=self.testing_base_folder,
                data_format=self.data_format,
                save_debug_images=self.save_debug_images,
                instance_image_radius_factor=instance_image_radius_factor,
                max_num_instances=16,
                train_id_file=self.train_id_file,
                val_id_file=self.val_id_file,
                image_gaussian_blur_sigma=2.0,
                label_gaussian_blur_sigma=label_gaussian_blur_sigma,
                pad_image=False,
                crop_image_size=None,
                load_merged=False,
                load_has_complete_seg=False,
                load_seg_loss_mask=False,
                create_instances_bac=False,
                create_instances_merged=False)
            val_base_folder = self.testing_base_folder
        else:
            val_dataset_source = self.dataset
            val_base_folder = self.training_base_folder
        self.dataset_val = val_dataset_source.dataset_val_single_frame()
        self.setup_base_folder = os.path.join(val_base_folder, 'setup')
        self.video_frame_list_file_name = os.path.join(self.setup_base_folder, 'frames.csv')
        self.iterator_val = IdListIterator(
            os.path.join(self.setup_base_folder, 'video_only_all.csv'),
            random=False,
            keys=['video_id'])
        self.files_to_copy = ['dataset.py', 'network.py', 'main.py']

    def initNetworks(self):
        """Build the training graph, the solver and the single-frame validation graph.

        The validation graph reuses the weights of the training template ('net/rnn' scope with
        reuse=True) and exposes the LSTM state inputs/outputs so that states can be carried
        across frames explicitly at test time.
        """
        network_image_size = self.image_size
        network_output_size = self.output_size
        num_instances = None
        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1, self.num_frames] + network_image_size),
                ('instances_merged', [num_instances, self.num_frames] + network_output_size),
                ('instances_bac', [1, self.num_frames] + network_output_size)
            ])
            data_generator_entries_single_frame = OrderedDict([
                ('image', [1] + network_image_size),
                ('instances_merged', [num_instances] + network_output_size),
                ('instances_bac', [1] + network_output_size)
            ])
            # with the leading batch axis, channels are axis 1
            channel_axis = 1
        else:
            data_generator_entries = OrderedDict([
                ('image', [self.num_frames] + network_image_size + [1]),
                ('instances_merged', [self.num_frames] + network_output_size + [num_instances]),
                ('instances_bac', [self.num_frames] + network_output_size + [1])
            ])
            data_generator_entries_single_frame = OrderedDict([
                ('image', network_image_size + [1]),
                ('instances_merged', network_output_size + [num_instances]),
                ('instances_bac', network_output_size + [1])
            ])
            # channels are the last axis; normalizing over axis 1 would normalize over image rows
            channel_axis = -1

        # create model with shared weights between train and val
        lstm_net = network
        training_net = tf.make_template('net', lstm_net)

        def loss_function(prediction, groundtruth):
            return cosine_embedding_per_instance_loss(
                prediction, groundtruth,
                data_format=self.data_format,
                normalize=True,
                term_1_squared=True,
                l=1.0)

        # build train graph
        self.train_queue = DataGeneratorPadding(
            self.dataset_train, self.coord, data_generator_entries,
            batch_size=self.batch_size, queue_size=64)
        data, tracking, instances_bac = self.train_queue.dequeue()
        embeddings, embeddings_2 = training_net(
            data,
            num_outputs_embedding=self.embeddings_dim,
            is_training=True,
            data_format=self.data_format)

        # losses, first and second hourglass
        with tf.variable_scope('loss'):
            tracking_embedding_loss = self.embedding_factors['tra'] * loss_function(embeddings, tracking)
            bac_embedding_loss = self.embedding_factors['bac'] * loss_function(embeddings, instances_bac)
        with tf.variable_scope('loss_2'):
            tracking_embedding_loss_2 = self.embedding_factors['tra'] * loss_function(embeddings_2, tracking)
            bac_embedding_loss_2 = self.embedding_factors['bac'] * loss_function(embeddings_2, instances_bac)
        self.loss_net = tracking_embedding_loss + bac_embedding_loss + tracking_embedding_loss_2 + bac_embedding_loss_2

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            regularization_losses = []
            if self.reg_constant > 0:
                # L2 weight decay on all convolution/dense kernels
                regularization_losses = [
                    tf.nn.l2_loss(tf_var) for tf_var in tf.trainable_variables() if 'kernel' in tf_var.name
                ]
            if regularization_losses:
                self.loss_reg = self.reg_constant * tf.add_n(regularization_losses)
            else:
                # Always define loss_reg: train_losses/val_losses below report it unconditionally,
                # and tf.add_n fails on an empty list.
                self.loss_reg = tf.constant(0.0, dtype=self.loss_net.dtype)
            self.loss = self.loss_net + self.loss_reg

        self.train_losses = OrderedDict([
            ('loss_tra_emb', tracking_embedding_loss),
            ('loss_bac_emb', bac_embedding_loss),
            ('loss_reg', self.loss_reg),
            ('loss_tra_emb_2', tracking_embedding_loss_2),
            ('loss_bac_emb_2', bac_embedding_loss_2)
        ])

        # solver
        global_step = tf.Variable(self.current_iter)
        learning_rate = tf.train.piecewise_constant(global_step, self.learning_rate_boundaries, self.learning_rates)
        self.optimizer = tf.contrib.opt.NadamOptimizer(learning_rate=learning_rate).minimize(
            self.loss, global_step=global_step)

        # initialize variables
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())

        print('Variables')
        for variable in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
            print(variable)

        # build val graph
        val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(
            data_generator_entries_single_frame, shape_prefix=[1])
        self.data_val = val_placeholders['image']
        self.tracking_val = val_placeholders['instances_merged']
        self.instances_bac_val = val_placeholders['instances_bac']

        with tf.variable_scope('net/rnn', reuse=True):
            (self.embeddings_val, self.embeddings_2_val,
             self.lstm_input_states_val, self.lstm_output_states_val) = network_single_frame_with_lstm_states(
                self.data_val,
                num_outputs_embedding=self.embeddings_dim,
                data_format=self.data_format)
            # normalize each pixel's embedding vector over the channel axis of the current data_format
            self.embeddings_normalized_val = tf.nn.l2_normalize(self.embeddings_val, dim=channel_axis)
            self.embeddings_normalized_2_val = tf.nn.l2_normalize(self.embeddings_2_val, dim=channel_axis)

        with tf.variable_scope('loss'):
            self.tracking_embedding_loss_val = self.embedding_factors['tra'] * loss_function(
                self.embeddings_val, self.tracking_val)
            self.bac_embedding_loss_val = self.embedding_factors['bac'] * loss_function(
                self.embeddings_val, self.instances_bac_val)
        with tf.variable_scope('loss_2'):
            self.tracking_embedding_loss_2_val = self.embedding_factors['tra'] * loss_function(
                self.embeddings_2_val, self.tracking_val)
            self.bac_embedding_loss_2_val = self.embedding_factors['bac'] * loss_function(
                self.embeddings_2_val, self.instances_bac_val)
        self.loss_val = (self.tracking_embedding_loss_val + self.bac_embedding_loss_val
                         + self.tracking_embedding_loss_2_val + self.bac_embedding_loss_2_val)
        self.val_losses = OrderedDict([
            ('loss_tra_emb', self.tracking_embedding_loss_val),
            ('loss_bac_emb', self.bac_embedding_loss_val),
            ('loss_reg', self.loss_reg),
            ('loss_tra_emb_2', self.tracking_embedding_loss_2_val),
            ('loss_bac_emb_2', self.bac_embedding_loss_2_val)
        ])

    def test(self):
        """Run the test pass: on the challenge videos if test_on_challenge_data, else on the validation videos."""
        if self.test_on_challenge_data:
            self.test_test()
        else:
            self.test_val()

    def _predict_video(self, video_id, video_frames, loss_update_ops=None, collect_images=False):
        """Run the single-frame network over one video, feeding each frame's LSTM states into the next frame.

        :param video_id: The id of the video.
        :param video_frames: The frame ids of the video, in the order they are processed.
        :param loss_update_ops: If not None, the label placeholders are fed as well, and self.loss_val
            plus these loss-aggregator update ops are run for every frame.
        :param collect_images: If true, also collect the network input image of every frame.
        :return: Tuple (embeddings, embeddings_2, images) of per-frame lists. The leading batch axis
            of the embeddings is squeezed away; images stays empty unless collect_images is set.
        """
        embeddings_list = []
        embeddings_2_list = []
        images = []
        num_lstm_states = len(self.lstm_output_states_val)
        # The first frame is run without feeding LSTM states (assumes the state inputs have defaults).
        lstm_states = None
        for video_frame in video_frames:
            dataset_entry = self.dataset_val.get({
                'video_id': video_id,
                'frame_id': video_frame,
                'unique_id': video_id + '_' + video_frame
            })
            generators = dataset_entry['generators']
            feed_dict = {self.data_val: np.expand_dims(generators['image'], axis=0)}
            fetches = [self.embeddings_normalized_val, self.embeddings_normalized_2_val] + list(self.lstm_output_states_val)
            if loss_update_ops is not None:
                feed_dict[self.tracking_val] = np.expand_dims(generators['instances_merged'], axis=0)
                feed_dict[self.instances_bac_val] = np.expand_dims(generators['instances_bac'], axis=0)
                fetches += [self.loss_val] + loss_update_ops
            if lstm_states is not None:
                for state_input, state_value in zip(self.lstm_input_states_val, lstm_states):
                    feed_dict[state_input] = state_value
            run_tuple = self.sess.run(fetches, feed_dict=feed_dict)
            embeddings_list.append(np.squeeze(run_tuple[0], axis=0))
            embeddings_2_list.append(np.squeeze(run_tuple[1], axis=0))
            # Slice by the number of state tensors, not from the end: slicing with
            # -len(update_ops) yields an empty list when there are no update ops.
            lstm_states = run_tuple[2:2 + num_lstm_states]
            if collect_images:
                images.append(generators['image'])
        return embeddings_list, embeddings_2_list, images

    def _write_video_volume(self, frames, video_id, suffix):
        """Stack per-frame arrays along axis 1 and write them to out_first/iter_<current_iter>/<video_id><suffix>."""
        utils.io.image.write_np(
            np.stack(frames, axis=1),
            os.path.join(self.output_folder, 'out_first/iter_' + str(self.current_iter) + '/' + video_id + suffix))

    def test_val(self):
        """Predict embeddings for all validation videos, write them, and finalize the validation losses."""
        print('Testing val dataset...')
        video_id_frames = utils.io.text.load_dict_csv(self.video_frame_list_file_name)
        loss_update_ops = list(self.val_loss_aggregator.get_update_ops())
        num_entries = self.iterator_val.num_entries()
        for current_entry_index in range(num_entries):
            video_id = self.iterator_val.get_next_id()['video_id']
            predictions, predictions_2, _ = self._predict_video(
                video_id, video_id_frames[video_id], loss_update_ops=loss_update_ops)
            # release each stacked volume's source list right after writing to bound peak memory
            self._write_video_volume(predictions, video_id, '_embeddings.mha')
            del predictions
            self._write_video_volume(predictions_2, video_id, '_embeddings_2.mha')
            del predictions_2
            tensorflow_train.utils.tensorflow_util.print_progress_bar(
                current_entry_index, num_entries, prefix='Testing ', suffix=' complete')
        # finalize loss values
        self.val_loss_aggregator.finalize(self.current_iter)

    def test_test(self):
        """Predict embeddings for all test videos and write them together with the input images."""
        print('Testing test dataset...')
        video_id_frames = utils.io.text.load_dict_csv(self.video_frame_list_file_name)
        num_entries = self.iterator_val.num_entries()
        for current_entry_index in range(num_entries):
            video_id = self.iterator_val.get_next_id()['video_id']
            predictions, predictions_2, images = self._predict_video(
                video_id, video_id_frames[video_id], collect_images=True)
            self._write_video_volume(predictions, video_id, '_embeddings.mha')
            del predictions
            self._write_video_volume(predictions_2, video_id, '_embeddings_2.mha')
            del predictions_2
            self._write_video_volume(images, video_id, '_image.mha')
            del images
            tensorflow_train.utils.tensorflow_util.print_progress_bar(
                current_entry_index, num_entries, prefix='Testing ', suffix=' complete')
def test_folder(self, base_folder, dataset, name, update_loss):
    """
    Test all videos of a dataset folder with a sliding window of self.num_frames frames.

    For every window, each frame is passed through self.test_cropped_image (which carries the LSTM
    states through the window). Consecutive frame pairs of intermediate embeddings are then stacked
    along self.time_stack_axis: all pairs of the first window, and only the newest pair of every
    later window. Depending on the save_* flags, it writes the embeddings, performs instance
    segmentation/tracking via InstanceTracker, and writes the challenge output images and track file.
    Videos with fewer than self.num_frames frames are skipped.
    :param base_folder: The base folder of the images to test (must contain setup/frames.csv and setup/video_only_all.csv).
    :param dataset: The dataset used for data preprocessing.
    :param name: The name of the dataset; used as prefix of the output folders.
    :param update_loss: If true, finalize the validation loss aggregator after all videos.
    """
    setup_base_folder = os.path.join(base_folder, 'setup')
    video_frame_list_file_name = os.path.join(setup_base_folder, 'frames.csv')
    iterator = IdListIterator(os.path.join(setup_base_folder, 'video_only_all.csv'), random=False, keys=['video_id'])
    video_id_frames = utils.io.text.load_dict_csv(video_frame_list_file_name)
    num_entries = iterator.num_entries()
    for current_entry_index in range(num_entries):
        video_id = iterator.get_next_id()['video_id']
        video_frames_all = video_id_frames[video_id]
        num_windows = len(video_frames_all) - self.num_frames + 1
        if num_windows <= 0:
            # No full window fits, so nothing would be predicted, and the instance post-processing
            # below would use an undefined (or a previous video's) dataset_entry.
            print('Skipping video', video_id, '- fewer than', self.num_frames, 'frames')
            continue
        frame_index = 0
        instance_tracker = InstanceTracker(**self.instance_tracker_parameters)
        current_all_embeddings_cropped = [[] for _ in range(len(self.embeddings_cropped_val))]
        for j in range(num_windows):
            print('Processing frame', j)
            current_lstm_states_cropped = []
            video_frames = video_frames_all[j:j + self.num_frames]
            current_all_intermediate_embeddings = []
            for k, video_frame in enumerate(video_frames):
                current_id = video_id + '_' + video_frame
                dataset_entry = dataset.get({
                    'video_id': video_id,
                    'frame_id': video_frame,
                    'unique_id': current_id
                })
                embeddings_cropped, all_intermediate_embeddings, current_lstm_states_cropped = self.test_cropped_image(
                    dataset_entry, current_lstm_states_cropped, return_all_intermediate_embeddings=True)
                current_all_intermediate_embeddings.append(all_intermediate_embeddings)
                if self.save_overlapping_embeddings and (j == 0 or k >= self.num_frames - 2):
                    for i, e in enumerate(embeddings_cropped):
                        # quantize to int8; clip first because e * 128 reaches 128 (out of int8 range)
                        # for a component equal to 1.0, which would otherwise wrap around
                        current_all_embeddings_cropped[i].append(np.clip(e * 128, -128, 127).astype(np.int8))

            # First window: every consecutive frame pair; later windows: only the newest pair
            # (the window advanced by one frame, so only that pair is new).
            pair_starts = range(self.num_frames - 1) if j == 0 else [self.num_frames - 2]
            for offset, i in enumerate(pair_starts):
                stacked_two_embeddings_tile_list = [
                    np.stack([current_all_intermediate_embeddings[i][tile_i],
                              current_all_intermediate_embeddings[i + 1][tile_i]],
                             axis=self.time_stack_axis)
                    for tile_i in range(len(current_all_intermediate_embeddings[0]))
                ]
                if self.save_instances:
                    instances = self.get_merged_instances(stacked_two_embeddings_tile_list)
                    instance_tracker.add_new_label_image(instances)
                if self.save_all_embeddings:
                    for tile_i, e in enumerate(stacked_two_embeddings_tile_list):
                        utils.io.image.write_np(
                            e,
                            self.output_file_for_current_iteration(
                                name + '_' + video_id, 'embeddings',
                                'frame_' + str(frame_index + offset).zfill(3) + '_tile_' + str(tile_i).zfill(2) + '.mha'),
                            compress=False)
            # one output frame per processed pair
            frame_index += len(pair_starts)

        if self.save_overlapping_embeddings:
            for i, embeddings_cropped_list in enumerate(current_all_embeddings_cropped):
                if embeddings_cropped_list:
                    current_embeddings = np.stack(embeddings_cropped_list, axis=1)
                    utils.io.image.write_np(
                        current_embeddings,
                        self.output_file_for_current_iteration(name + '_' + video_id, 'embeddings_cropped_' + str(i) + '.mha'))

        if self.save_instances:
            if self.save_intermediate_instance_images:
                utils.io.image.write_np(
                    instance_tracker.stacked_label_image.astype(np.uint16),
                    self.output_file_for_current_iteration(name + '_' + video_id, 'merged_instances.mha'))
            # resample the label images back to the original image geometry, using the size and
            # transformation of the last processed frame
            to_size = dataset_entry['datasources']['image'].GetSize()
            transformation = dataset_entry['transformations']['image_transformation']
            instance_tracker.resample_stacked_label_image(to_size, transformation, 1.0)
            if self.save_intermediate_instance_images:
                utils.io.image.write_np(
                    instance_tracker.stacked_label_image.astype(np.uint16),
                    self.output_file_for_current_iteration(name + '_' + video_id, 'merged_instances_resampled.mha'))
            instance_tracker.finalize()
            if self.save_intermediate_instance_images:
                utils.io.image.write_np(
                    instance_tracker.stacked_label_image.astype(np.uint16),
                    self.output_file_for_current_iteration(name + '_' + video_id, 'merged_instances_final.mha'))
            if self.save_challenge_instance_images:
                track_tuples = instance_tracker.track_tuples
                final_track_image_np = instance_tracker.stacked_label_image
                print('Saving output images and tracks...')
                # one 2D image per frame (first axis of the stacked label image)
                final_track_images_sitk = [
                    utils.sitk_np.np_to_sitk(np.squeeze(im))
                    for im in np.split(final_track_image_np, final_track_image_np.shape[0], axis=0)
                ]
                for i, final_track_image_sitk in enumerate(final_track_images_sitk):
                    video_frame = str(i).zfill(3)
                    utils.io.image.write(
                        final_track_image_sitk,
                        self.output_file_for_current_iteration(
                            name + '_' + video_id, 'instances', self.image_prefix + video_frame + '.tif'))
                utils.io.text.save_list_csv(
                    track_tuples,
                    self.output_file_for_current_iteration(name + '_' + video_id, 'instances', self.track_file_name),
                    delimiter=' ')

    # finalize loss values
    if update_loss:
        self.val_loss_aggregator.finalize(self.current_iter)