Example #1
    def __init__(self, dataset_name, base_folder, output_folder, model_file_name, num_embeddings, image_size,
                 coord_factors, min_samples, sigma, border_size, parent_dilation, parent_frame_search):
        self.coord_factors = coord_factors
        self.min_samples = min_samples
        self.sigma = sigma
        self.border_size = border_size
        self.parent_dilation = parent_dilation
        self.parent_frame_search = parent_frame_search
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        self.coord = tf.train.Coordinator()
        self.image_prefix = 'mask'
        self.track_file_name = 'res_track.txt'

        self.save_debug_images = False
        self.image_size = image_size
        self.data_format = 'channels_last'
        self.output_size = self.image_size
        self.num_embeddings = num_embeddings
        self.base_folder = base_folder
        self.output_folder = output_folder
        self.load_model_filename = model_file_name
        additional_scale = {'DIC-C2DH-HeLa': [1, 1],
                           'Fluo-C2DL-MSC': [0.95, 0.95],
                           'Fluo-N2DH-GOWT1': [1, 1],
                           'Fluo-N2DH-SIM+': [1, 1],
                           'Fluo-N2DL-HeLa': [1, 1],
                           'PhC-C2DH-U373': [1, 1],
                           'PhC-C2DL-PSC': [1, 1]}
        additional_scale = additional_scale[dataset_name]
        normalization_consideration_factors = {'DIC-C2DH-HeLa': (0.2, 0.1),
                                               'Fluo-C2DL-MSC': (0.2, 0.01),
                                               'Fluo-N2DH-GOWT1': (0.2, 0.1),
                                               'Fluo-N2DH-SIM+': (0.2, 0.01),
                                               'Fluo-N2DL-HeLa': (0.2, 0.1),
                                               'PhC-C2DH-U373': (0.2, 0.1),
                                               'PhC-C2DL-PSC': (0.2, 0.1)}
        normalization_consideration_factor = normalization_consideration_factors[dataset_name]
        self.dataset = Dataset(self.image_size,
                               base_folder=self.base_folder,
                               data_format=self.data_format,
                               save_debug_images=self.save_debug_images,
                               normalization_consideration_factors=normalization_consideration_factor,
                               additional_scale=additional_scale,
                               image_gaussian_blur_sigma=2.0,
                               pad_image=False)

        self.dataset_val = self.dataset.dataset_val_single_frame()

        if self.data_format == 'channels_first':
            self.channel_axis = 1
            self.time_stack_axis = 1
        else:
            self.channel_axis = 3
            self.time_stack_axis = 0
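
A minimal usage sketch for this variant, assuming the enclosing class is the MainLoop-style test driver shown in the later examples (so it also provides run_test()); every path and parameter value below is an illustrative placeholder:

loop = MainLoop(dataset_name='DIC-C2DH-HeLa',
                base_folder='path/to/01/',
                output_folder='path/to/output/',
                model_file_name='path/to/weights/model-10000',
                num_embeddings=16,
                image_size=[256, 256],
                coord_factors=0.005,
                min_samples=100,
                sigma=2,
                border_size=25,
                parent_dilation=0,
                parent_frame_search=2)
loop.run_test()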

Example #2
    def __init__(self, dataset_name, input_image_folder, output_folder,
                 model_file_name, num_embeddings, num_frames):
        self.dataset_name = dataset_name

        self.tf_config = tf.ConfigProto()
        self.tf_config.gpu_options.per_process_gpu_memory_fraction = 0.5
        self.tf_config.gpu_options.allow_growth = True

        self.sess = tf.Session(config=self.tf_config)
        self.coord = tf.train.Coordinator()
        self.image_prefix = 'mask'
        self.track_file_name = 'res_track.txt'

        self.save_all_embeddings = False
        self.save_all_input_images = False
        self.save_all_predictions = True
        self.image_size, self.test_image_size = image_sizes_for_dataset_name(
            dataset_name)
        self.tiled_increment = [128, 128]
        self.instances_ignore_border = [32, 32]
        self.num_frames = num_frames
        self.data_format = 'channels_last'
        self.output_size = self.image_size
        self.num_embeddings = num_embeddings
        self.input_image_folder = input_image_folder
        self.output_folder = output_folder
        self.load_model_filename = model_file_name
        self.dataset_parameters = get_dataset_parameters(dataset_name)
        self.instance_image_creator_parameters = get_instance_image_creator_parameters(
            dataset_name)
        self.instance_tracker_parameters = get_instance_tracker_parameters(
            dataset_name)
        self.dataset = Dataset(
            self.test_image_size,
            base_folder=self.input_image_folder,
            data_format=self.data_format,
            debug_image_folder=os.path.join(self.output_folder, 'input_images')
            if self.save_all_input_images else None,
            **self.dataset_parameters)

        self.dataset_val = self.dataset.dataset_val_single_frame()
        self.video_frames = glob(self.input_image_folder + '/*.tif')
        self.video_frames = sorted([
            os.path.splitext(os.path.basename(frame))[0][1:]
            for frame in self.video_frames
        ])

        if self.data_format == 'channels_first':
            self.channel_axis = 1
            self.time_stack_axis = 1
        else:
            self.channel_axis = 3
            self.time_stack_axis = 0
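
A usage sketch for this variant (the one that caps GPU memory at 50% and enables allow_growth), assuming the enclosing class provides run_test() as in Example #4, which shares this constructor signature; paths and values are placeholders:

loop = MainLoop(dataset_name='Fluo-N2DL-HeLa',
                input_image_folder='path/to/01/',
                output_folder='path/to/output/',
                model_file_name='path/to/weights/model-10000',
                num_embeddings=16,
                num_frames=8)
loop.run_test()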
Example #3
    def __init__(self, dataset_name, input_image_folder, output_folder,
                 model_file_name, num_embeddings, image_size, additional_scale,
                 normalization_consideration_factors, coord_factors,
                 min_samples, sigma, border_size, parent_dilation,
                 parent_frame_search):
        self.coord_factors = coord_factors
        self.min_samples = min_samples
        self.sigma = sigma
        self.border_size = border_size
        self.parent_dilation = parent_dilation
        self.parent_frame_search = parent_frame_search
        self.sess = tf.Session()
        self.coord = tf.train.Coordinator()
        self.image_prefix = 'mask'
        self.track_file_name = 'res_track.txt'

        self.save_all_embeddings = True
        self.save_all_input_images = True
        self.save_all_predictions = True
        self.save_debug_images = False
        self.image_size = image_size
        self.data_format = 'channels_last'
        self.output_size = self.image_size
        self.num_embeddings = num_embeddings
        self.input_image_folder = input_image_folder
        self.output_folder = output_folder
        self.load_model_filename = model_file_name
        self.dataset = Dataset(self.image_size,
                               base_folder=self.input_image_folder,
                               data_format=self.data_format,
                               save_debug_images=self.save_debug_images,
                               normalization_consideration_factors=
                               normalization_consideration_factors,
                               additional_scale=additional_scale,
                               image_gaussian_blur_sigma=2.0,
                               pad_image=False)

        self.dataset_val = self.dataset.dataset_val_single_frame()
        self.video_frames = glob.glob(self.input_image_folder + '/*.tif')
        self.video_frames = sorted([
            os.path.splitext(os.path.basename(frame))[0][1:]
            for frame in self.video_frames
        ])

        if self.data_format == 'channels_first':
            self.channel_axis = 1
            self.time_stack_axis = 1
        else:
            self.channel_axis = 3
            self.time_stack_axis = 0
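
The frame-id extraction above assumes Cell Tracking Challenge style file names such as t000.tif: the basename is split from its extension and the leading 't' is dropped. A standalone sketch of that idiom:

import os

frames = ['/data/01/t002.tif', '/data/01/t000.tif', '/data/01/t001.tif']
frame_ids = sorted(os.path.splitext(os.path.basename(f))[0][1:] for f in frames)
# frame_ids == ['000', '001', '002']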
Example #4
class MainLoop(object):
    def __init__(self, dataset_name, input_image_folder, output_folder,
                 model_file_name, num_embeddings, num_frames):
        self.dataset_name = dataset_name
        self.sess = tf.Session()
        self.coord = tf.train.Coordinator()
        self.image_prefix = 'mask'
        self.track_file_name = 'res_track.txt'

        self.save_all_embeddings = False
        self.save_all_input_images = False
        self.save_all_predictions = True
        self.image_size, self.test_image_size = image_sizes_for_dataset_name(
            dataset_name)
        self.tiled_increment = [128, 128]
        self.instances_ignore_border = [32, 32]
        self.num_frames = num_frames
        self.data_format = 'channels_first'
        self.output_size = self.image_size
        self.num_embeddings = num_embeddings
        self.input_image_folder = input_image_folder
        self.output_folder = output_folder
        self.load_model_filename = model_file_name
        self.dataset_parameters = get_dataset_parameters(dataset_name)
        self.instance_image_creator_parameters = get_instance_image_creator_parameters(
            dataset_name)
        self.instance_tracker_parameters = get_instance_tracker_parameters(
            dataset_name)
        self.dataset = Dataset(
            self.test_image_size,
            base_folder=self.input_image_folder,
            data_format=self.data_format,
            debug_image_folder=os.path.join(self.output_folder, 'input_images')
            if self.save_all_input_images else None,
            **self.dataset_parameters)

        self.dataset_val = self.dataset.dataset_val_single_frame()
        self.video_frames = glob(self.input_image_folder + '/*.tif')
        self.video_frames = sorted([
            os.path.splitext(os.path.basename(frame))[0][1:]
            for frame in self.video_frames
        ])

        if self.data_format == 'channels_first':
            self.channel_axis = 1
            self.time_stack_axis = 1
        else:
            self.channel_axis = 3
            self.time_stack_axis = 0

    def create_output_folder(self):
        create_directories(self.output_folder)

    def load_model(self):
        self.saver = tf.train.Saver()
        print('Restoring model ' + self.load_model_filename)
        self.saver.restore(self.sess, self.load_model_filename)

    def init_all(self):
        self.init_networks()
        self.load_model()
        self.create_output_folder()

    def run_test(self):
        self.init_all()
        print('Starting main test loop')
        self.test()

    def init_networks(self):
        network_image_size = self.image_size
        network_output_size = self.output_size

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1, self.num_frames] + network_image_size),
                ('instances_merged', [None] + network_output_size),
                ('instances_bac', [None] + network_output_size)
            ])
        else:
            data_generator_entries = OrderedDict([
                ('image', network_image_size + [1]),
                ('instances_merged', network_output_size + [None]),
                ('instances_bac', network_output_size + [None])
            ])

        # build val graph
        val_placeholders = create_placeholders(data_generator_entries,
                                               shape_prefix=[1])
        self.data_val = val_placeholders['image']

        with tf.variable_scope('net'):
            self.embeddings_0, self.embeddings_1 = network(
                self.data_val,
                num_outputs_embedding=self.num_embeddings,
                data_format=self.data_format,
                actual_network=HourglassNet3D,
                is_training=False)
            self.embeddings_normalized_0 = tf.nn.l2_normalize(
                self.embeddings_0, dim=self.channel_axis)
            self.embeddings_normalized_1 = tf.nn.l2_normalize(
                self.embeddings_1, dim=self.channel_axis)
            self.embeddings_cropped_val = (self.embeddings_normalized_0,
                                           self.embeddings_normalized_1)

    def test_cropped_image(self,
                           dataset_entry,
                           return_all_intermediate_embeddings=False):
        generators = dataset_entry['generators']
        full_image = generators['image']
        # initialize sizes based on data_format
        fetches = self.embeddings_cropped_val
        if self.data_format == 'channels_first':
            image_size_np = [1, self.num_frames] + list(
                reversed(self.image_size))
            full_image_size_np = list(full_image.shape)
            embeddings_size_np = [self.num_embeddings, self.num_frames] + list(
                reversed(self.image_size))
            full_embeddings_size_np = [self.num_embeddings] + list(
                full_image.shape[1:])
            inc = [0, 0] + list(reversed(self.tiled_increment))
        else:
            image_size_np = list(reversed(self.image_size)) + [1]
            full_image_size_np = list(full_image.shape)
            embeddings_size_np = list(reversed(
                self.image_size)) + [self.num_embeddings]
            full_embeddings_size_np = list(
                full_image.shape[0:2]) + [self.num_embeddings]
            inc = list(reversed(self.tiled_increment)) + [0]
        # initialize one image tiler for the input and a list of image tilers for the embeddings
        image_tiler = ImageTiler(full_image_size_np, image_size_np, inc, True,
                                 -1)
        embeddings_tilers = tuple([
            ImageTiler(full_embeddings_size_np, embeddings_size_np, inc, True,
                       -1) for _ in range(len(self.embeddings_cropped_val))
        ])

        all_intermediate_embeddings = []
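        # iterating a tiler advances its tile window and yields the tiler
        # itself, so zipping the input tiler with the embeddings tilers keeps
        # all tile positions in sync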
        for state_index, all_tilers in enumerate(
                zip(*((image_tiler, ) + embeddings_tilers))):
            image_tiler = all_tilers[0]
            embeddings_tilers = all_tilers[1:]
            current_image = image_tiler.get_current_data(full_image)
            feed_dict = {self.data_val: np.expand_dims(current_image, axis=0)}
            run_tuple = self.sess.run(fetches, feed_dict)
            image_tiler.set_current_data(current_image)
            for i, embeddings_tiler in enumerate(embeddings_tilers):
                embeddings = np.squeeze(run_tuple[i], axis=0)
                if return_all_intermediate_embeddings and i == len(
                        embeddings_tilers) - 1:
                    all_intermediate_embeddings.append(embeddings)
                embeddings_tiler.set_current_data(embeddings)

        embeddings = [
            embeddings_tiler.output_image
            for embeddings_tiler in embeddings_tilers
        ]

        if return_all_intermediate_embeddings:
            return embeddings, all_intermediate_embeddings
        else:
            return embeddings

    def merge_tiled_instances(self, tiled_instances):
        # initialize sizes based on data_format
        instances_size_np = [2] + list(reversed(self.image_size))
        full_instances_size_np = [2] + list(reversed(self.test_image_size))
        inc = [0] + list(reversed(self.tiled_increment))
        # initialize one image tiler for the merged instances
        instance_tiler = ImageTiler(full_instances_size_np,
                                    instances_size_np,
                                    inc,
                                    True,
                                    0,
                                    output_image_dtype=np.uint16)
        instance_merger = InstanceMerger(
            ignore_border=self.instances_ignore_border)

        for i, instance_tiler in enumerate(instance_tiler):
            current_instance_pair = tiled_instances[i]
            instance_tiler.set_current_data(
                current_instance_pair,
                instance_merger.merge_as_larger_instances,
                merge_whole_image=True)
        instances = instance_tiler.output_image

        return instances

    def get_instances(self, stacked_two_embeddings):
        clusterer = InstanceImageCreator(
            **self.instance_image_creator_parameters)
        clusterer.create_instance_image(stacked_two_embeddings)
        instances = clusterer.label_image
        return instances

    def get_merged_instances(self, stacked_two_embeddings_tile_list):
        tiled_instances = []
        for stacked_two_embeddings in stacked_two_embeddings_tile_list:
            if self.data_format == 'channels_last':
                stacked_two_embeddings = np.transpose(stacked_two_embeddings,
                                                      [3, 0, 1, 2])
            instances = self.get_instances(stacked_two_embeddings)
            tiled_instances.append(instances)
        merged_instances = self.merge_tiled_instances(tiled_instances)
        return merged_instances

    def test(self):
        if len(self.video_frames) == 0:
            print('No images found!')
            return
        print('Testing...', self.input_image_folder)

        video_frames_all = self.video_frames
        frame_index = 0
        instance_tracker = InstanceTracker(**self.instance_tracker_parameters)
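        # slide a window of num_frames consecutive frames over the video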
        for j in range(len(video_frames_all) - self.num_frames + 1):
            print('Processing frame', j)
            video_frames = video_frames_all[j:j + self.num_frames]
            current_all_intermediate_embeddings = []
            images = []
            for k, video_frame in enumerate(video_frames):
                dataset_entry = self.dataset_val.get({'image_id': video_frame})
                image = dataset_entry['generators']['image']
                images.append(image)
            stacked_image = np.stack(images, axis=1)
            print(stacked_image.shape)
            utils.io.image.write_np(
                stacked_image,
                os.path.join(self.output_folder, 'input',
                             'frame_' + str(j) + '.mha'))
            full_dataset_entry = {'generators': {'image': stacked_image}}
            embeddings_cropped, all_intermediate_embeddings = self.test_cropped_image(
                full_dataset_entry, return_all_intermediate_embeddings=True)
            current_all_intermediate_embeddings.append(
                all_intermediate_embeddings)

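            # the first window (j == 0) emits instances for all of its
            # num_frames - 1 frame pairs; every later window adds only its
            # last pair, e.g. with num_frames = 4: j = 0 gives pairs (0,1),
            # (1,2), (2,3), j = 1 gives (3,4), j = 2 gives (4,5), and so on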
            if j == 0:
                for i in range(self.num_frames - 1):
                    stacked_two_embeddings_tile_list = []
                    for tile_i in range(
                            len(current_all_intermediate_embeddings[0])):
                        stacked_two_embeddings_tile_list.append(
                            np.stack([
                                current_all_intermediate_embeddings[0][tile_i]
                                [:, i, :, :],
                                current_all_intermediate_embeddings[0][tile_i]
                                [:, i + 1, :, :]
                            ],
                                     axis=self.time_stack_axis))
                    instances = self.get_merged_instances(
                        stacked_two_embeddings_tile_list)
                    instance_tracker.add_new_label_image(instances)
                    if self.save_all_embeddings:
                        for tile_i, e in enumerate(
                                stacked_two_embeddings_tile_list):
                            utils.io.image.write_np(
                                e,
                                os.path.join(
                                    self.output_folder, 'embeddings',
                                    'frame_' + str(j + i).zfill(3) + '_tile_' +
                                    str(tile_i).zfill(2) + '.mha'),
                                compress=False)
            else:
                stacked_two_embeddings_tile_list = []
                for tile_i in range(len(
                        current_all_intermediate_embeddings[0])):
                    stacked_two_embeddings_tile_list.append(
                        np.stack([
                            current_all_intermediate_embeddings[0][tile_i]
                            [:, self.num_frames - 2, :, :],
                            current_all_intermediate_embeddings[0][tile_i]
                            [:, self.num_frames - 1, :, :]
                        ],
                                 axis=self.time_stack_axis))
                instances = self.get_merged_instances(
                    stacked_two_embeddings_tile_list)
                instance_tracker.add_new_label_image(instances)
                if self.save_all_embeddings:
                    for tile_i, e in enumerate(
                            stacked_two_embeddings_tile_list):
                        utils.io.image.write_np(
                            e,
                            os.path.join(
                                self.output_folder, 'embeddings',
                                'frame_' + str(frame_index).zfill(3) +
                                '_tile_' + str(tile_i).zfill(2) + '.mha'),
                            compress=False)
            if j == 0:
                frame_index += self.num_frames - 1
            else:
                frame_index += 1

            if self.save_all_predictions:
                utils.io.image.write_np(
                    instance_tracker.stacked_label_image.astype(np.uint16),
                    os.path.join(self.output_folder, 'predictions',
                                 'merged_instances.mha'))

        to_size = dataset_entry['datasources']['image'].GetSize()
        transformation = dataset_entry['transformations']['image']
        #transformation = scale_transformation_for_image_sizes(from_size, to_size, [0.95, 0.95] if self.dataset_name == 'Fluo-C2DL-MSC' else [1.0, 1.0])
        instance_tracker.resample_stacked_label_image(to_size, transformation,
                                                      1.0)
        if self.save_all_predictions:
            utils.io.image.write_np(
                instance_tracker.stacked_label_image.astype(np.uint16),
                os.path.join(self.output_folder, 'predictions',
                             'merged_instances_resampled.mha'))

        instance_tracker.finalize()
        if self.save_all_predictions:
            utils.io.image.write_np(
                instance_tracker.stacked_label_image.astype(np.uint16),
                os.path.join(self.output_folder, 'predictions',
                             'merged_instances_final.mha'))

        track_tuples = instance_tracker.track_tuples
        final_track_image_np = instance_tracker.stacked_label_image

        print('Saving output images and tracks...')
        final_track_images_sitk = [
            utils.sitk_np.np_to_sitk(np.squeeze(im)) for im in np.split(
                final_track_image_np, final_track_image_np.shape[0], axis=0)
        ]
        for i, final_track_image_sitk in enumerate(final_track_images_sitk):
            video_frame = str(i).zfill(3)
            utils.io.image.write(
                final_track_image_sitk,
                os.path.join(self.output_folder,
                             self.image_prefix + video_frame + '.tif'))
        utils.io.text.save_list_csv(track_tuples,
                                    os.path.join(self.output_folder,
                                                 self.track_file_name),
                                    delimiter=' ')
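
Based on image_prefix, track_file_name, and the write calls above, a run of this example should leave roughly the following layout in output_folder (illustrative, with the defaults save_all_predictions=True and save_all_embeddings=False):

# output_folder/
#   mask000.tif, mask001.tif, ...         one uint16 label image per frame
#   res_track.txt                         track tuples, space-delimited
#   input/frame_<j>.mha                   stacked input windows
#   predictions/merged_instances*.mha     written because save_all_predictions is True
#   embeddings/frame_*_tile_*.mha         only if save_all_embeddings is True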
Example #5
class MainLoop(object):
    def __init__(self, dataset_name, input_image_folder, output_folder,
                 model_file_name, num_embeddings, image_size, additional_scale,
                 normalization_consideration_factors, coord_factors,
                 min_samples, sigma, border_size, parent_dilation,
                 parent_frame_search):
        self.coord_factors = coord_factors
        self.min_samples = min_samples
        self.sigma = sigma
        self.border_size = border_size
        self.parent_dilation = parent_dilation
        self.parent_frame_search = parent_frame_search
        self.sess = tf.Session()
        self.coord = tf.train.Coordinator()
        self.image_prefix = 'mask'
        self.track_file_name = 'res_track.txt'

        self.save_all_embeddings = True
        self.save_all_input_images = True
        self.save_all_predictions = True
        self.save_debug_images = False
        self.image_size = image_size
        self.data_format = 'channels_last'
        self.output_size = self.image_size
        self.num_embeddings = num_embeddings
        self.input_image_folder = input_image_folder
        self.output_folder = output_folder
        self.load_model_filename = model_file_name
        self.dataset = Dataset(self.image_size,
                               base_folder=self.input_image_folder,
                               data_format=self.data_format,
                               save_debug_images=self.save_debug_images,
                               normalization_consideration_factors=
                               normalization_consideration_factors,
                               additional_scale=additional_scale,
                               image_gaussian_blur_sigma=2.0,
                               pad_image=False)

        self.dataset_val = self.dataset.dataset_val_single_frame()
        self.video_frames = glob.glob(self.input_image_folder + '/*.tif')
        self.video_frames = sorted([
            os.path.splitext(os.path.basename(frame))[0][1:]
            for frame in self.video_frames
        ])

        if self.data_format == 'channels_first':
            self.channel_axis = 1
            self.time_stack_axis = 1
        else:
            self.channel_axis = 3
            self.time_stack_axis = 0

    def create_output_folder(self):
        create_directories(self.output_folder)

    def load_model(self):
        self.saver = tf.train.Saver()
        print('Restoring model ' + self.load_model_filename)
        self.saver.restore(self.sess, self.load_model_filename)

    def init_all(self):
        self.init_networks()
        self.load_model()
        self.create_output_folder()

    def run_test(self):
        self.init_all()
        print('Starting main test loop')
        self.test()

    def init_networks(self):
        network_image_size = self.image_size
        network_output_size = self.output_size

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1] + network_image_size),
                ('instances_merged', [None] + network_output_size),
                ('instances_bac', [None] + network_output_size)
            ])
        else:
            data_generator_entries = OrderedDict([
                ('image', network_image_size + [1]),
                ('instances_merged', network_output_size + [None]),
                ('instances_bac', network_output_size + [None])
            ])

        # build val graph
        val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(
            data_generator_entries, shape_prefix=[1])
        self.data_val = val_placeholders['image']
        self.tracking_val = val_placeholders['instances_merged']
        self.instances_bac_val = val_placeholders['instances_bac']

        with tf.variable_scope('net/rnn'):
            self.embeddings_0, self.embeddings_1, self.lstm_input_states, self.lstm_output_states = network_single_frame_with_lstm_states(
                self.data_val,
                num_outputs_embedding=self.num_embeddings,
                data_format=self.data_format)
            self.embeddings_normalized_0 = tf.nn.l2_normalize(
                self.embeddings_0, dim=self.channel_axis)
            self.embeddings_normalized_1 = tf.nn.l2_normalize(
                self.embeddings_1, dim=self.channel_axis)

    def test(self):
        #label_stack = utils.sitk_np.sitk_to_np(utils.io.image.read(os.path.join(self.output_folder, 'label_stack.mha'), sitk_pixel_type=sitk.sitkVectorUInt16))
        #label_stack = np.transpose(label_stack, [0, 3, 1, 2])

        if len(self.video_frames) == 0:
            print('No images found!')
            return
        print('Testing...', self.input_image_folder)
        tracker = EmbeddingTracker(
            coord_factors=self.coord_factors,
            stack_neighboring_slices=2,
            min_cluster_size=self.min_samples,
            min_samples=self.min_samples,
            min_label_size_per_stack=self.min_samples / 2,
            save_label_stack=True,
            image_ignore_border=self.border_size,
            parent_search_dilation_size=self.parent_dilation,
            max_parent_search_frames=self.parent_frame_search)
        #tracker.set_label_stack(label_stack)
        first = True
        current_images = []
        current_lstm_states = []
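        # LSTM states are carried from frame to frame: only the first frame
        # runs without feeding the previous output states back in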
        for i, video_frame in enumerate(self.video_frames):
            with Timer('processing video frame ' + str(video_frame)):
                dataset_entry = self.dataset_val.get({'image_id': video_frame})
                generators = dataset_entry['generators']
                feed_dict = {
                    self.data_val: np.expand_dims(generators['image'], axis=0)
                }
                if not first:
                    for j in range(len(self.lstm_input_states)):
                        feed_dict[
                            self.lstm_input_states[j]] = current_lstm_states[j]
                run_tuple = self.sess.run([self.embeddings_normalized_1] +
                                          list(self.lstm_output_states),
                                          feed_dict=feed_dict)
                embeddings_normalized_1 = np.squeeze(run_tuple[0], axis=0)
                current_lstm_states = run_tuple[1:]
                if self.data_format == 'channels_last':
                    embeddings_normalized_1 = np.transpose(
                        embeddings_normalized_1, [2, 0, 1])
                tracker.add_slice(embeddings_normalized_1)
                if self.save_all_input_images:
                    current_images.append(generators['image'])
                first = False

        # finalize tracker and resample to input resolution
        transformations = dataset_entry['transformations']
        transformation = transformations['image']
        datasources = dataset_entry['datasources']
        input_image = datasources['image']

        #utils.io.image.write_np(np.stack(tracker.label_stack_list, axis=1), os.path.join(self.output_folder, 'label_stack.mha'))
        tracker.finalize()
        tracker.resample_stacked_label_image(input_image, transformation,
                                             self.sigma)
        tracker.fix_tracks_after_resampling()
        track_tuples = tracker.track_tuples
        final_track_image_np = tracker.stacked_label_image

        print('Saving output images and tracks...')
        final_track_images_sitk = [
            utils.sitk_np.np_to_sitk(np.squeeze(im)) for im in np.split(
                final_track_image_np, final_track_image_np.shape[0], axis=0)
        ]
        for video_frame, final_track_image_sitk in zip(
                self.video_frames, final_track_images_sitk):
            utils.io.image.write(
                final_track_image_sitk,
                os.path.join(self.output_folder,
                             self.image_prefix + video_frame + '.tif'))
        utils.io.text.save_list_csv(track_tuples,
                                    os.path.join(self.output_folder,
                                                 self.track_file_name),
                                    delimiter=' ')

        if self.save_all_embeddings:
            # embeddings are always 'channels_first'
            embeddings = np.stack(tracker.embeddings_slices, axis=1)
            utils.io.image.write_np(
                embeddings, os.path.join(self.output_folder, 'embeddings.mha'),
                'channels_first')
        if self.save_all_input_images:
            images = np.stack(current_images, axis=self.time_stack_axis)
            utils.io.image.write_np(
                images, os.path.join(self.output_folder, 'image.mha'),
                self.data_format)
        if self.save_all_predictions:
            predictions = np.stack(tracker.stacked_label_image,
                                   axis=self.time_stack_axis)
            utils.io.image.write_np(
                predictions, os.path.join(self.output_folder,
                                          'predictions.mha'), self.data_format)

Example #6
class MainLoop(object):
    def __init__(self, dataset_name, base_folder, output_folder,
                 model_file_name, num_embeddings, image_size, coord_factors,
                 min_samples, sigma, border_size, parent_dilation,
                 parent_frame_search):
        self.coord_factors = coord_factors
        self.min_samples = min_samples
        self.sigma = sigma
        self.border_size = border_size
        self.parent_dilation = parent_dilation
        self.parent_frame_search = parent_frame_search
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        self.coord = tf.train.Coordinator()
        self.image_prefix = 'mask'
        self.track_file_name = 'res_track.txt'

        self.save_debug_images = False
        self.image_size = image_size
        self.data_format = 'channels_last'
        self.output_size = self.image_size
        self.num_embeddings = num_embeddings
        self.base_folder = base_folder
        self.output_folder = output_folder
        self.load_model_filename = model_file_name
        additional_scale = {
            'DIC-C2DH-HeLa': [1, 1],
            'Fluo-C2DL-MSC': [0.95, 0.95],
            'Fluo-N2DH-GOWT1': [1, 1],
            'Fluo-N2DH-SIM+': [1, 1],
            'Fluo-N2DL-HeLa': [1, 1],
            'PhC-C2DH-U373': [1, 1],
            'PhC-C2DL-PSC': [1, 1]
        }
        additional_scale = additional_scale[dataset_name]
        normalization_consideration_factors = {
            'DIC-C2DH-HeLa': (0.2, 0.1),
            'Fluo-C2DL-MSC': (0.2, 0.01),
            'Fluo-N2DH-GOWT1': (0.2, 0.1),
            'Fluo-N2DH-SIM+': (0.2, 0.01),
            'Fluo-N2DL-HeLa': (0.2, 0.1),
            'PhC-C2DH-U373': (0.2, 0.1),
            'PhC-C2DL-PSC': (0.2, 0.1)
        }
        normalization_consideration_factor = normalization_consideration_factors[
            dataset_name]
        self.dataset = Dataset(self.image_size,
                               base_folder=self.base_folder,
                               data_format=self.data_format,
                               save_debug_images=self.save_debug_images,
                               normalization_consideration_factors=
                               normalization_consideration_factor,
                               additional_scale=additional_scale,
                               image_gaussian_blur_sigma=2.0,
                               pad_image=False)

        self.dataset_val = self.dataset.dataset_val_single_frame()

        if self.data_format == 'channels_first':
            self.channel_axis = 1
            self.time_stack_axis = 1
        else:
            self.channel_axis = 3
            self.time_stack_axis = 0

    def __del__(self):
        self.sess.close()
        tf.reset_default_graph()

    # def loadModel(self):
    #    saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net/first_frame_net') + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net/data_conv'))
    #    model_filename = 'weights/model-' + str(self.current_iter)
    #    print('Restoring model ' + model_filename)
    #    saver.restore(self.sess, model_filename)

    def init_networks(self):
        network_image_size = self.image_size
        network_output_size = self.output_size

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1] + network_image_size),
                ('instances_merged', [None] + network_output_size),
                ('instances_bac', [None] + network_output_size)
            ])
        else:
            data_generator_entries = OrderedDict([
                ('image', network_image_size + [1]),
                ('instances_merged', network_output_size + [None]),
                ('instances_bac', network_output_size + [None])
            ])

        # build val graph
        val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(
            data_generator_entries, shape_prefix=[1])
        self.data_val = val_placeholders['image']
        self.tracking_val = val_placeholders['instances_merged']
        self.instances_bac_val = val_placeholders['instances_bac']

        with tf.variable_scope('net/rnn'):
            self.embeddings_val, self.embeddings_2_val, self.lstm_input_states_val, self.lstm_output_states_val = network_single_frame_with_lstm_states(
                self.data_val,
                num_outputs_embedding=self.num_embeddings,
                data_format=self.data_format)
            self.embeddings_normalized_val = tf.nn.l2_normalize(
                self.embeddings_val, dim=self.channel_axis)
            self.embeddings_normalized_2_val = tf.nn.l2_normalize(
                self.embeddings_2_val, dim=self.channel_axis)

    def test(self):
        print('Testing...', self.base_folder)
        video_frames = glob.glob(self.base_folder + '*.tif')
        video_frames = sorted([
            os.path.splitext(os.path.basename(frame))[0][1:]
            for frame in video_frames
        ])
        #video_frames = video_frames[100:]

        #coord_factors = 0.001
        #min_cluster_size = 100
        #min_samples = 100
        #min_label_size_per_stack = 100
        tracker = EmbeddingTracker(
            coord_factors=self.coord_factors,
            stack_neighboring_slices=2,
            min_cluster_size=self.min_samples,
            min_samples=self.min_samples,
            min_label_size_per_stack=self.min_samples,
            save_label_stack=True,
            image_ignore_border=self.border_size,
            parent_search_dilation_size=self.parent_dilation,
            max_parent_search_frames=self.parent_frame_search)

        first = True
        current_predictions = []
        current_predictions_2 = []
        current_images = []
        # reset_every_frames = 20
        for i, video_frame in enumerate(video_frames):
            #if int(video_frame) < 150 or int(video_frame) > 250:
            #    continue
            with Timer('processing video frame ' + str(video_frame)):
                dataset_entry = self.dataset_val.get({'image_id': video_frame})
                datasources = dataset_entry['datasources']
                generators = dataset_entry['generators']
                feed_dict = {
                    self.data_val: np.expand_dims(generators['image'], axis=0)
                }
                # run loss and update loss accumulators
                if not first:
                    for j in range(len(self.lstm_input_states_val)):
                        feed_dict[self.lstm_input_states_val[
                            j]] = current_lstm_states[j]

                run_tuple = self.sess.run([
                    self.embeddings_normalized_val,
                    self.embeddings_normalized_2_val
                ] + list(self.lstm_output_states_val),
                                          feed_dict=feed_dict)
                # print(iv[0].decode())
                embeddings_softmax = np.squeeze(run_tuple[0], axis=0)
                embeddings_softmax_2 = np.squeeze(run_tuple[1], axis=0)
                current_lstm_states = run_tuple[2:]
                #current_predictions.append(embeddings_softmax)
                #current_predictions_2.append(embeddings_softmax_2)
                current_images.append(generators['image'])
                # current_instances.append(instance_segmentation_test.get_instances_cosine_kmeans_2d(embeddings_softmax))
                first = False

                datasources = dataset_entry['datasources']
                input_image = datasources['image']
                transformations = dataset_entry['transformations']
                transformation = transformations['image']
                # embeddings_original = utils.sitk_image.transform_np_output_to_sitk_input(embeddings_softmax_2,
                #                                                                        output_spacing=None,
                #                                                                        channel_axis=2,
                #                                                                        input_image_sitk=input_image,
                #                                                                        transform=transformation,
                #                                                                        interpolator='linear',
                #                                                                        output_pixel_type=sitk.sitkFloat32)
                # embeddings_softmax_2 = utils.sitk_np.sitk_list_to_np(embeddings_original, axis=2)

                current_predictions_2.append(embeddings_softmax_2)

                tracker.add_slice(np.transpose(embeddings_softmax_2,
                                               [2, 0, 1]))

                if tracker.stacked_label_image is not None:
                    utils.io.image.write_np(
                        tracker.stacked_label_image,
                        os.path.join(self.output_folder, 'merged.mha'))

                # if not first and i % reset_every_frames != 0:
                #     run_tuple = self.sess.run([self.embeddings_normalized_val, self.embeddings_normalized_2_val] + list(self.lstm_output_states_val), feed_dict=feed_dict)
                #     embeddings_softmax_2 = np.squeeze(run_tuple[1], axis=0)
                #     tracker.add_reset_slice(np.transpose(embeddings_softmax_2, [2, 0, 1]))

        # prediction = np.stack(current_predictions, axis=self.time_stack_axis)
        # del current_predictions
        # utils.io.image.write_np(prediction, os.path.join(self.output_folder, 'embeddings.mha'), self.data_format)
        # del prediction
        prediction_2 = np.stack(current_predictions_2,
                                axis=self.time_stack_axis)
        del current_predictions_2
        utils.io.image.write_np(
            prediction_2, os.path.join(self.output_folder, 'embeddings_2.mha'),
            self.data_format)
        del prediction_2
        images = np.stack(current_images, axis=self.time_stack_axis)
        del current_images
        utils.io.image.write_np(images,
                                os.path.join(self.output_folder, 'image.mha'),
                                self.data_format)
        del images
        transformations = dataset_entry['transformations']
        transformation = transformations['image']
        sitk.WriteTransform(transformation,
                            os.path.join(self.output_folder, 'transform.txt'))

        #if self.data_format == 'channels_last':
        #    prediction_2 = np.transpose(prediction_2, [3, 0, 1, 2])

        # two_slices = tracker.get_instances_cosine_dbscan_slice_by_slice(prediction_2)
        # utils.io.image.write_np(two_slices, os.path.join(self.output_folder, 'two_slices.mha'))
        # merged = tracker.merge_consecutive_slices(two_slices, slice_neighbour_size=2)
        # utils.io.image.write_np(merged, os.path.join(self.output_folder, 'merged.mha'), self.data_format)

        datasources = dataset_entry['datasources']
        input_image = datasources['image']
        if self.sigma == 1:
            interpolator = 'label_gaussian'
        else:
            interpolator = 'nearest'

        tracker.finalize()
        track_tuples = tracker.track_tuples
        merged = tracker.stacked_label_image

        final_predictions = utils.sitk_image.transform_np_output_to_sitk_input(
            merged,
            output_spacing=None,
            channel_axis=0,
            input_image_sitk=input_image,
            transform=transformation,
            interpolator=interpolator,
            output_pixel_type=sitk.sitkUInt16)
        #final_predictions = [utils.sitk_np.np_to_sitk(np.squeeze(im), type=np.uint16) for im in np.split(merged, merged.shape[0], axis=0)]
        #final_predictions_smoothed_2 = [utils.sitk_image.apply_np_image_function(im, lambda x: self.label_smooth(x, sigma=2)) for im in final_predictions]
        if self.sigma > 1:
            final_predictions = [
                utils.sitk_image.apply_np_image_function(
                    im, lambda x: self.label_smooth(x, sigma=self.sigma))
                for im in final_predictions
            ]

        for video_frame, final_prediction in zip(video_frames,
                                                 final_predictions):
            utils.io.image.write(
                final_prediction,
                os.path.join(self.output_folder,
                             self.image_prefix + video_frame + '.tif'))

        utils.io.image.write_np(
            np.stack(tracker.label_stack_list, axis=1),
            os.path.join(self.output_folder, 'label_stack.mha'))

        final_predictions_stacked = utils.sitk_image.accumulate(
            final_predictions)
        utils.io.image.write(final_predictions_stacked,
                             os.path.join(self.output_folder, 'stacked.mha'))
        #utils.io.image.write(utils.sitk_image.accumulate(final_predictions_smoothed_2), os.path.join(self.output_folder, 'stacked_2.mha'))
        #utils.io.image.write(utils.sitk_image.accumulate(final_predictions_smoothed_4), os.path.join(self.output_folder, 'stacked_4.mha'))

        print(track_tuples)
        utils.io.text.save_list_csv(track_tuples,
                                    os.path.join(self.output_folder,
                                                 self.track_file_name),
                                    delimiter=' ')

    def label_smooth(self, im, sigma):
        label_images, labels = utils.np_image.split_label_image_with_unknown_labels(
            im, dtype=np.float32)
        smoothed_label_images = utils.np_image.smooth_label_images(
            label_images, sigma=sigma, dtype=im.dtype)
        return utils.np_image.merge_label_images(smoothed_label_images, labels)
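
    # Note: label_smooth splits the label image into one binary image per
    # label, smooths each independently, and merges them back; test() applies
    # it per frame via utils.sitk_image.apply_np_image_function when sigma > 1.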

    def testx(self):
        label_stack = utils.sitk_np.sitk_to_np(
            utils.io.image.read(os.path.join(self.output_folder,
                                             'label_stack.mha'),
                                sitk_pixel_type=sitk.sitkVectorUInt16))
        label_stack = np.transpose(label_stack, [0, 3, 1, 2])
        tracker = EmbeddingTracker(
            coord_factors=self.coord_factors,
            stack_neighboring_slices=2,
            min_cluster_size=self.min_samples * 2,
            min_samples=self.min_samples,
            min_label_size_per_stack=self.min_samples,
            save_label_stack=True,
            image_ignore_border=self.border_size,
            parent_search_dilation_size=self.parent_dilation)
        tracker.set_label_stack(label_stack)

        video_frames = glob.glob(self.base_folder + '*.tif')
        video_frames = sorted([
            os.path.splitext(os.path.basename(frame))[0][1:]
            for frame in video_frames
        ])

        dataset_entry = self.dataset_val.get({'image_id': video_frames[0]})
        datasources = dataset_entry['datasources']
        input_image = datasources['image']
        transformations = dataset_entry['transformations']
        transformation = transformations['image']

        if self.sigma == 1:
            interpolator = 'label_gaussian'
        else:
            interpolator = 'nearest'

        tracker.finalize()
        track_tuples = tracker.track_tuples
        merged = tracker.stacked_label_image

        final_predictions = utils.sitk_image.transform_np_output_to_sitk_input(
            merged,
            output_spacing=None,
            channel_axis=0,
            input_image_sitk=input_image,
            transform=transformation,
            interpolator=interpolator,
            output_pixel_type=sitk.sitkUInt16)
        # final_predictions = [utils.sitk_np.np_to_sitk(np.squeeze(im), type=np.uint16) for im in np.split(merged, merged.shape[0], axis=0)]
        # final_predictions_smoothed_2 = [utils.sitk_image.apply_np_image_function(im, lambda x: self.label_smooth(x, sigma=2)) for im in final_predictions]
        if self.sigma > 1:
            final_predictions = [
                utils.sitk_image.apply_np_image_function(
                    im, lambda x: self.label_smooth(x, sigma=self.sigma))
                for im in final_predictions
            ]

        for video_frame, final_prediction in zip(video_frames,
                                                 final_predictions):
            utils.io.image.write(
                final_prediction,
                os.path.join(self.output_folder,
                             self.image_prefix + video_frame + '.tif'))

        utils.io.image.write_np(
            np.stack(tracker.label_stack_list, axis=1),
            os.path.join(self.output_folder, 'label_stack.mha'))

        final_predictions_stacked = utils.sitk_image.accumulate(
            final_predictions)
        utils.io.image.write(final_predictions_stacked,
                             os.path.join(self.output_folder, 'stacked.mha'))
        # utils.io.image.write(utils.sitk_image.accumulate(final_predictions_smoothed_2), os.path.join(self.output_folder, 'stacked_2.mha'))
        # utils.io.image.write(utils.sitk_image.accumulate(final_predictions_smoothed_4), os.path.join(self.output_folder, 'stacked_4.mha'))

        print(track_tuples)
        utils.io.text.save_list_csv(track_tuples,
                                    os.path.join(self.output_folder,
                                                 self.track_file_name),
                                    delimiter=' ')

    def create_output_folder(self):
        create_directories(self.output_folder)

    def load_model(self):
        self.saver = tf.train.Saver()
        print('Restoring model ' + self.load_model_filename)
        self.saver.restore(self.sess, self.load_model_filename)

    def init_all(self):
        self.init_networks()
        self.load_model()
        self.create_output_folder()

    def run_test(self):
        self.init_all()
        print('Starting main test loop')
        self.test()
Example #7
    def __init__(self, dataset_name, input_image_folder, output_folder,
                 model_file_name, num_embeddings, num_frames):
        self.dataset_name = dataset_name
        self.sess = tf.Session()
        self.coord = tf.train.Coordinator()
        self.image_prefix = 'mask'
        self.track_file_name = 'res_track.txt'

        self.save_all_embeddings = False
        self.save_all_input_images = False
        self.save_all_predictions = True
        self.hdbscan = False
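        # flip to True to replace the clustering parameters below with the
        # per-dataset HDBSCAN settings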
        self.image_size, self.test_image_size = image_sizes_for_dataset_name(
            dataset_name)
        self.tiled_increment = [128, 128]
        self.instances_ignore_border = [32, 32]
        self.num_frames = num_frames
        self.data_format = 'channels_last'
        self.output_size = self.image_size
        self.num_embeddings = num_embeddings
        self.input_image_folder = input_image_folder
        self.output_folder = output_folder
        self.load_model_filename = model_file_name
        self.dataset_parameters = get_dataset_parameters(dataset_name)
        self.instance_image_creator_parameters = get_instance_image_creator_parameters(
            dataset_name)
        if self.hdbscan:
            if self.dataset_name == 'DIC-C2DH-HeLa':
                min_cluster_size = 1000
                coord_factors = 0.02
            elif self.dataset_name == 'Fluo-N2DH-GOWT1':
                min_cluster_size = 1000
                coord_factors = 0.001
            elif self.dataset_name == 'PhC-C2DH-U373':
                min_cluster_size = 500
                coord_factors = 0.005
            elif self.dataset_name == 'Fluo-N2DL-HeLa':
                min_cluster_size = 25
                coord_factors = 0.01
            self.instance_image_creator_parameters[
                'min_cluster_size'] = min_cluster_size
            self.instance_image_creator_parameters[
                'coord_factors'] = coord_factors
            self.instance_image_creator_parameters['hdbscan'] = True
        self.instance_tracker_parameters = get_instance_tracker_parameters(
            dataset_name)
        self.dataset = Dataset(
            self.test_image_size,
            base_folder=self.input_image_folder,
            data_format=self.data_format,
            debug_image_folder=os.path.join(self.output_folder, 'input_images')
            if self.save_all_input_images else None,
            **self.dataset_parameters)

        self.dataset_val = self.dataset.dataset_val_single_frame()
        self.video_frames = glob(self.input_image_folder + '/*.tif')
        self.video_frames = sorted([
            os.path.splitext(os.path.basename(frame))[0][1:]
            for frame in self.video_frames
        ])

        if self.data_format == 'channels_first':
            self.channel_axis = 1
            self.time_stack_axis = 1
        else:
            self.channel_axis = 3
            self.time_stack_axis = 0
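
A closing sketch: this variant shares its constructor signature with Examples #2 and #4, so it can be driven the same way; here it is run over the two standard CTC sequence folders (paths and values are placeholders, and run_test() is assumed to come from the rest of the class as in Example #4):

for sequence in ['01', '02']:
    # in TF1, constructing the network twice in one process would need a fresh
    # default graph and session, compare the __del__ in Example #6
    loop = MainLoop(dataset_name='PhC-C2DH-U373',
                    input_image_folder='path/to/' + sequence + '/',
                    output_folder='path/to/' + sequence + '_RES/',
                    model_file_name='path/to/weights/model-10000',
                    num_embeddings=16,
                    num_frames=8)
    loop.run_test()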