Example #1
    def init_networks(self):
        """
        Initialize networks and placeholders.
        """
        network_image_size = list(reversed(self.image_size))

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([('image', [1] + network_image_size),
                                                  ('single_label', [self.num_labels] + network_image_size),
                                                  ('single_heatmap', [1] + network_image_size)])
        else:
            data_generator_entries = OrderedDict([('image', network_image_size + [1]),
                                                  ('single_label', network_image_size + [self.num_labels]),
                                                  ('single_heatmap', network_image_size + [1])])

        data_generator_types = {'image':  tf.float32,
                                'labels': tf.uint8}

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGenerator(coord=self.coord, dataset=self.dataset_train, data_names_and_shapes=data_generator_entries, data_types=data_generator_types, batch_size=self.batch_size)
        data, mask, single_heatmap = self.train_queue.dequeue()
        data_heatmap_concat = tf.concat([data, single_heatmap], axis=1)
        prediction = training_net(data_heatmap_concat, num_labels=self.num_labels, is_training=True, actual_network=self.unet, padding=self.padding, **self.network_parameters)
        # losses
        self.loss_net = self.loss_function(labels=mask, logits=prediction, data_format=self.data_format)
        self.loss_reg = get_reg_loss(self.reg_constant)
        self.loss = self.loss_net + self.loss_reg

        # solver
        global_step = tf.Variable(self.current_iter, trainable=False)
        learning_rate = tf.train.piecewise_constant(global_step, self.learning_rate_boundaries, self.learning_rates)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        unclipped_gradients, variables = zip(*optimizer.compute_gradients(self.loss))
        norm = tf.global_norm(unclipped_gradients)
        if self.clip_gradient_global_norm > 0:
            gradients, _ = tf.clip_by_global_norm(unclipped_gradients, self.clip_gradient_global_norm)
        else:
            gradients = unclipped_gradients
        self.optimizer = optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)
        self.train_losses = OrderedDict([('loss', self.loss_net), ('loss_reg', self.loss_reg), ('gradient_norm', norm)])

        # build val graph
        self.data_val, self.mask_val, self.single_heatmap_val = create_placeholders_tuple(data_generator_entries, data_types=data_generator_types, shape_prefix=[1])
        self.data_heatmap_concat_val = tf.concat([self.data_val, self.single_heatmap_val], axis=1)
        self.prediction_val = training_net(self.data_heatmap_concat_val, num_labels=self.num_labels, is_training=False, actual_network=self.unet, padding=self.padding, **self.network_parameters)
        self.prediction_softmax_val = tf.nn.sigmoid(self.prediction_val)

        if self.has_validation_groundtruth:
            self.loss_val = self.loss_function(labels=self.mask_val, logits=self.prediction_val, data_format=self.data_format)
            self.val_losses = OrderedDict([('loss', self.loss_val), ('loss_reg', self.loss_reg), ('gradient_norm', tf.constant(0, tf.float32))])
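The loss_function used above is not shown in this excerpt. A minimal sketch of a compatible loss, assuming per-voxel sigmoid cross-entropy (consistent with the tf.nn.sigmoid applied to the validation prediction); the name sigmoid_cross_entropy_loss and the exact reduction are assumptions, not the repository's actual implementation:

# Hypothetical sketch only, not the repository's actual loss_function.
import tensorflow as tf

def sigmoid_cross_entropy_loss(labels, logits, data_format='channels_first'):
    """Mean per-voxel sigmoid cross-entropy; data_format is kept only for signature compatibility."""
    labels = tf.cast(labels, logits.dtype)
    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(cross_entropy)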
Example #2
    def initNetworks(self):
        net = tf.make_template('net', self.network)

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([('image', [self.image_channels] + list(reversed(self.image_size))),
                                                  ('landmarks', [self.num_landmarks, 3])])
            data_generator_entries_val = OrderedDict([('image', [self.image_channels] + list(reversed(self.image_size))),
                                                      ('landmarks', [self.num_landmarks, 3])])
        else:
            data_generator_entries = OrderedDict([('image', list(reversed(self.image_size)) + [self.image_channels]),
                                                  ('landmarks', [self.num_landmarks, 3])])
            data_generator_entries_val = OrderedDict([('image', list(reversed(self.image_size)) + [self.image_channels]),
                                                      ('landmarks', [self.num_landmarks, 3])])

        sigmas = tf.get_variable('sigmas', [self.num_landmarks], initializer=tf.constant_initializer(self.heatmap_sigma))
        sigmas_list = [(f's{i}', sigmas[i]) for i in range(self.num_landmarks)]

        # build training graph
        self.train_queue = DataGenerator(self.dataset_train, self.coord, data_generator_entries, batch_size=self.batch_size, n_threads=8)
        placeholders = self.train_queue.dequeue()
        image = placeholders[0]
        target_landmarks = placeholders[1]
        prediction = net(image, num_landmarks=self.num_landmarks, is_training=True, data_format=self.data_format)
        target_heatmaps = generate_heatmap_target(list(reversed(self.heatmap_size)), target_landmarks, sigmas, scale=self.sigma_scale, normalize=True, data_format=self.data_format)
        loss_sigmas = self.loss_sigmas(sigmas, target_landmarks)
        self.loss_reg = get_reg_loss(self.reg_constant)
        self.loss_net = self.loss_function(target_heatmaps, prediction)
        self.loss = self.loss_net + tf.cast(self.loss_reg, tf.float32) + loss_sigmas

        # optimizer
        global_step = tf.Variable(self.current_iter, trainable=False)
        optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate, momentum=0.99, use_nesterov=True)
        unclipped_gradients, variables = zip(*optimizer.compute_gradients(self.loss))
        norm = tf.global_norm(unclipped_gradients)
        gradients, _ = tf.clip_by_global_norm(unclipped_gradients, 10000.0)
        self.optimizer = optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)
        self.train_losses = OrderedDict([('loss', self.loss_net), ('loss_reg', self.loss_reg), ('loss_sigmas', loss_sigmas), ('norm', norm)] + sigmas_list)

        # build val graph
        self.val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(data_generator_entries_val, shape_prefix=[1])
        self.image_val = self.val_placeholders['image']
        self.target_landmarks_val = self.val_placeholders['landmarks']
        self.prediction_val = net(self.image_val, num_landmarks=self.num_landmarks, is_training=False, data_format=self.data_format)
        self.target_heatmaps_val = generate_heatmap_target(list(reversed(self.heatmap_size)), self.target_landmarks_val, sigmas, scale=self.sigma_scale, normalize=True, data_format=self.data_format)

        # losses
        self.loss_val = self.loss_function(self.target_heatmaps_val, self.prediction_val)
        self.val_losses = OrderedDict([('loss', self.loss_val), ('loss_reg', self.loss_reg), ('loss_sigmas', tf.constant(0, tf.float32)), ('norm', tf.constant(0, tf.float32))] + sigmas_list)
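generate_heatmap_target is a framework helper whose source is not included here. Conceptually it is assumed to render one Gaussian per landmark at its coordinate, with a per-landmark (possibly learnable) sigma; the NumPy sketch below illustrates that idea only. The function name gaussian_heatmaps is hypothetical, the coordinate layout is simplified (the landmark tensors in these examples appear to carry an extra leading validity entry, used as target_landmarks[0, :, 0] in the sigma losses, which the sketch ignores), and the meaning of normalize is assumed.

# Hypothetical NumPy illustration of the assumed behaviour of generate_heatmap_target.
import numpy as np

def gaussian_heatmaps(heatmap_size, landmark_coords, sigmas, scale=1.0, normalize=True):
    """heatmap_size: e.g. [H, W] or [D, H, W]; landmark_coords: (num_landmarks, dim) positions."""
    grids = np.meshgrid(*[np.arange(s) for s in heatmap_size], indexing='ij')
    heatmaps = []
    for coord, sigma in zip(landmark_coords, sigmas):
        squared_dist = sum((g - c) ** 2 for g, c in zip(grids, coord))
        h = np.exp(-squared_dist / (2.0 * sigma ** 2))
        if normalize:
            # assumed: unit-integral Gaussian before scaling
            h /= (np.sqrt(2.0 * np.pi) * sigma) ** len(heatmap_size)
        heatmaps.append(scale * h)
    return np.stack(heatmaps, axis=0)  # channels_first layout: (num_landmarks, ...)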
Example #3
    def init_networks(self):
        """
        Initialize networks and placeholders.
        """
        network_image_size = list(reversed(self.image_size))

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([('image', [1] + network_image_size),
                                                  ('spine_heatmap', [1] + network_image_size)])
        else:
            data_generator_entries = OrderedDict([('image', network_image_size + [1]),
                                                  ('spine_heatmap', network_image_size + [1])])

        data_generator_types = {'image': tf.float32,
                                'spine_heatmap': tf.float32}


        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGenerator(coord=self.coord, dataset=self.dataset_train, data_names_and_shapes=data_generator_entries, data_types=data_generator_types, batch_size=self.batch_size)
        data, target_spine_heatmap = self.train_queue.dequeue()

        prediction = training_net(data, num_labels=self.num_labels, is_training=True, actual_network=self.unet, padding=self.padding, **self.network_parameters)
        self.loss_net = self.loss_function(target=target_spine_heatmap, pred=prediction)
        self.loss_reg = get_reg_loss(self.reg_constant)
        self.loss = self.loss_net + tf.cast(self.loss_reg, tf.float32)

        # solver
        global_step = tf.Variable(self.current_iter, trainable=False)
        learning_rate = tf.train.piecewise_constant(global_step, self.learning_rate_boundaries, self.learning_rates)
        self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss, global_step=global_step)
        self.train_losses = OrderedDict([('loss', self.loss_net), ('loss_reg', self.loss_reg)])

        # build val graph
        self.data_val, self.target_spine_heatmap_val = create_placeholders_tuple(data_generator_entries, data_types=data_generator_types, shape_prefix=[1])
        self.prediction_val = training_net(self.data_val, num_labels=self.num_labels, is_training=False, actual_network=self.unet, padding=self.padding, **self.network_parameters)

        if self.has_validation_groundtruth:
            self.loss_val = self.loss_function(target=self.target_spine_heatmap_val, pred=self.prediction_val)
            self.val_losses = OrderedDict([('loss', self.loss_val), ('loss_reg', self.loss_reg)])
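get_reg_loss is not shown in these excerpts. A sketch consistent with the explicit regularization pattern written out in Example #4 (collecting tf.GraphKeys.REGULARIZATION_LOSSES and scaling by the regularization constant); treat it as an assumption rather than the framework's exact code:

# Hypothetical sketch of get_reg_loss, mirroring the explicit pattern in Example #4.
import tensorflow as tf

def get_reg_loss(reg_constant):
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if reg_constant > 0 and len(reg_losses) > 0:
        return reg_constant * tf.add_n(reg_losses)
    return tf.constant(0.0, tf.float32)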
Example #4
    def initNetworks(self):
        net = tf.make_template('net', self.network)

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([('image', [self.image_channels] + list(reversed(self.image_size))),
                                                  ('landmarks', [self.num_landmarks] + list(reversed(self.heatmap_size)))])
            data_generator_entries_val = OrderedDict([('image', [self.image_channels] + list(reversed(self.image_size))),
                                                      ('landmarks', [self.num_landmarks] + list(reversed(self.heatmap_size)))])
        else:
            raise NotImplementedError

        self.train_queue = DataGenerator(self.dataset_train, self.coord, data_generator_entries, batch_size=self.batch_size, n_threads=8)
        placeholders = self.train_queue.dequeue()
        image = placeholders[0]
        landmarks = placeholders[1]
        prediction = net(image, num_landmarks=self.num_landmarks, is_training=True, data_format=self.data_format)
        self.loss_net = self.loss_function(landmarks, prediction)

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            if self.reg_constant > 0:
                reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
                self.loss_reg = self.reg_constant * tf.add_n(reg_losses)
                self.loss = self.loss_net + self.loss_reg
            else:
                self.loss_reg = 0
                self.loss = self.loss_net

        self.train_losses = OrderedDict([('loss', self.loss_net), ('loss_reg', self.loss_reg)])
        #self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss)
        self.optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate, momentum=0.99, use_nesterov=True).minimize(self.loss)

        # build val graph
        self.val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(data_generator_entries_val, shape_prefix=[1])
        self.image_val = self.val_placeholders['image']
        self.landmarks_val = self.val_placeholders['landmarks']
        self.prediction_val = net(self.image_val, num_landmarks=self.num_landmarks, is_training=False, data_format=self.data_format)

        # losses
        self.loss_val = self.loss_function(self.landmarks_val, self.prediction_val)
        self.val_losses = OrderedDict([('loss', self.loss_val), ('loss_reg', self.loss_reg)])
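create_placeholders (and the create_placeholders_tuple variant used in other examples) comes from tensorflow_train.utils.tensorflow_util and is not shown here. It is assumed to build one tf.placeholder per entry with shape_prefix prepended to the entry's shape; a minimal sketch under that assumption:

# Hypothetical sketch of the assumed placeholder helpers; the real utilities may differ.
import tensorflow as tf
from collections import OrderedDict

def create_placeholders(names_and_shapes, data_types=None, shape_prefix=None):
    shape_prefix = list(shape_prefix or [])
    data_types = data_types or {}
    return OrderedDict([(name, tf.placeholder(data_types.get(name, tf.float32),
                                              shape_prefix + list(shape),
                                              name=name))
                        for name, shape in names_and_shapes.items()])

def create_placeholders_tuple(names_and_shapes, data_types=None, shape_prefix=None):
    return tuple(create_placeholders(names_and_shapes, data_types, shape_prefix).values())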
Example #5
    def init_networks(self):
        """
        Initialize networks and placeholders.
        """
        network_image_size = list(reversed(self.image_size))

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1] + network_image_size),
                ('landmarks', [self.num_landmarks, 4]),
                ('landmark_mask', [1] + network_image_size)
            ])
        else:
            data_generator_entries = OrderedDict([
                ('image', network_image_size + [1]),
                ('landmarks', [self.num_landmarks, 4]),
                ('landmark_mask', network_image_size + [1])
            ])

        data_generator_types = {'image': tf.float32}

        # create sigmas variable
        sigmas = tf.get_variable('sigmas', [self.num_landmarks],
                                 initializer=tf.constant_initializer(
                                     self.heatmap_sigma))
        if not self.learnable_sigma:
            sigmas = tf.stop_gradient(sigmas)
        mean_sigmas = tf.reduce_mean(sigmas)

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGenerator(
            coord=self.coord,
            dataset=self.dataset_train,
            data_names_and_shapes=data_generator_entries,
            data_types=data_generator_types,
            batch_size=self.batch_size)
        data, target_landmarks, landmark_mask = self.train_queue.dequeue()
        target_heatmaps = generate_heatmap_target(list(
            reversed(self.heatmap_size)),
                                                  target_landmarks,
                                                  sigmas,
                                                  scale=self.sigma_scale,
                                                  normalize=True,
                                                  data_format=self.data_format)
        prediction, local_prediction, spatial_prediction = training_net(
            data,
            num_labels=self.num_landmarks,
            is_training=True,
            actual_network=self.unet,
            padding=self.padding,
            **self.network_parameters)
        # losses
        self.loss_net = self.loss_function(target=target_heatmaps,
                                           pred=prediction,
                                           mask=landmark_mask)
        self.loss_sigmas = self.loss_function_sigmas(sigmas,
                                                     target_landmarks[0, :, 0])
        self.loss_reg = get_reg_loss(self.reg_constant)
        self.loss = self.loss_net + self.loss_reg + self.loss_sigmas

        # solver
        global_step = tf.Variable(self.current_iter, trainable=False)
        learning_rate = tf.train.piecewise_constant(
            global_step, self.learning_rate_boundaries, self.learning_rates)
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=0.99,
                                               use_nesterov=True)
        unclipped_gradients, variables = zip(
            *optimizer.compute_gradients(self.loss))
        norm = tf.global_norm(unclipped_gradients)
        if self.clip_gradient_global_norm > 0:
            gradients, _ = tf.clip_by_global_norm(
                unclipped_gradients, self.clip_gradient_global_norm)
        else:
            gradients = unclipped_gradients
        self.optimizer = optimizer.apply_gradients(zip(gradients, variables),
                                                   global_step=global_step)
        self.train_losses = OrderedDict([('loss', self.loss_net),
                                         ('loss_reg', self.loss_reg),
                                         ('loss_sigmas', self.loss_sigmas),
                                         ('mean_sigmas', mean_sigmas),
                                         ('gradient_norm', norm)])

        # build val graph
        self.data_val, self.target_landmarks_val, self.landmark_mask_val = create_placeholders_tuple(
            data_generator_entries,
            data_types=data_generator_types,
            shape_prefix=[1])
        self.target_heatmaps_val = generate_heatmap_target(
            list(reversed(self.heatmap_size)),
            self.target_landmarks_val,
            sigmas,
            scale=self.sigma_scale,
            normalize=True,
            data_format=self.data_format)
        self.prediction_val, self.local_prediction_val, self.spatial_prediction_val = training_net(
            self.data_val,
            num_labels=self.num_landmarks,
            is_training=False,
            actual_network=self.unet,
            padding=self.padding,
            **self.network_parameters)

        if self.has_validation_groundtruth:
            self.loss_val = self.loss_function(target=self.target_heatmaps_val,
                                               pred=self.prediction_val)
            self.val_losses = OrderedDict([
                ('loss', self.loss_val), ('loss_reg', self.loss_reg),
                ('loss_sigmas', tf.constant(0, tf.float32)),
                ('mean_sigmas', tf.constant(0, tf.float32)),
                ('gradient_norm', tf.constant(0, tf.float32))
            ])
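loss_function_sigmas is not included in this excerpt. A sketch consistent with the sigma regularization written out explicitly in Example #6, where the per-landmark entries passed in (target_landmarks[0, :, 0]) appear to act as validity weights; the standalone function form and the default constant are assumptions:

# Hypothetical sketch of loss_function_sigmas, mirroring the sigma regularization in Example #6.
import tensorflow as tf

def loss_function_sigmas(sigmas, landmark_weights, sigma_regularization=100.0):
    # landmark_weights: per-landmark weights, e.g. target_landmarks[0, :, 0]
    return sigma_regularization * tf.nn.l2_loss(sigmas * landmark_weights)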
Example #6
class MainLoop(MainLoopBase):
    def __init__(self, cv, network_id):
        super().__init__()
        self.cv = cv
        self.network_id = network_id
        self.output_folder = network_id
        if cv != -1:
            self.output_folder += '_cv{}'.format(cv)
        self.output_folder += '/' + self.output_folder_timestamp()
        self.batch_size = 1
        learning_rates = {'scn': 0.00000005, 'unet': 0.000000005}
        max_iters = {'scn': 40000, 'unet': 80000}
        self.learning_rate = learning_rates[self.network_id]
        self.max_iter = max_iters[self.network_id]
        self.test_iter = 2500
        self.disp_iter = 100
        self.snapshot_iter = self.test_iter
        self.test_initialization = False
        self.current_iter = 0
        self.reg_constant = 0.0005
        self.sigma_regularization = 100
        self.sigma_scale = 1000
        self.invert_transformation = False
        self.num_landmarks = 26
        self.image_size = [96, 96, 192]
        self.image_spacing = [2, 2, 2]
        self.heatmap_size = self.image_size
        self.image_channels = 1
        self.heatmap_sigma = 4
        self.data_format = 'channels_first'
        self.save_debug_images = False
        self.base_folder = 'spine_localization_dataset'
        self.generate_landmarks = True
        self.cropped_training = True
        self.cropped_inc = [0, 64, 0, 0]
        if self.cropped_training:
            dataset = Dataset(self.image_size,
                              self.image_spacing,
                              self.heatmap_sigma,
                              self.num_landmarks,
                              self.base_folder,
                              self.cv,
                              self.data_format,
                              self.save_debug_images,
                              generate_heatmaps=not self.generate_landmarks,
                              generate_landmarks=self.generate_landmarks)
            self.dataset_train = dataset.dataset_train()
            dataset = Dataset(self.image_size,
                              self.image_spacing,
                              self.heatmap_sigma,
                              self.num_landmarks,
                              self.base_folder,
                              self.cv,
                              self.data_format,
                              self.save_debug_images,
                              generate_heatmaps=not self.generate_landmarks,
                              generate_landmarks=self.generate_landmarks)
            self.dataset_val = dataset.dataset_val()
        else:
            dataset = Dataset(self.image_size,
                              self.image_spacing,
                              self.heatmap_sigma,
                              self.num_landmarks,
                              self.base_folder,
                              self.cv,
                              self.data_format,
                              self.save_debug_images,
                              generate_heatmaps=not self.generate_landmarks,
                              generate_landmarks=self.generate_landmarks,
                              translate_by_random_factor=False)
            self.dataset_train = dataset.dataset_train()
            self.dataset_val = dataset.dataset_val()

        networks = {'scn': network_scn, 'unet': network_unet}
        self.network = networks[self.network_id]

        self.point_statistics_names = [
            'pe_mean', 'pe_stdev', 'pe_median', 'num_correct'
        ]
        self.additional_summaries_placeholders_val = dict([
            (name, create_summary_placeholder(name))
            for name in self.point_statistics_names
        ])

    def loss_function(self, pred, target):
        batch_size, _, _ = get_batch_channel_image_size(pred, self.data_format)
        return tf.nn.l2_loss(pred - target) / batch_size

    def initNetworks(self):
        net = tf.make_template('scn', self.network)

        if self.data_format == 'channels_first':
            if self.generate_landmarks:
                data_generator_entries = OrderedDict([
                    ('image',
                     [self.image_channels] + list(reversed(self.image_size))),
                    ('landmarks', [self.num_landmarks, 4])
                ])
            else:
                data_generator_entries = OrderedDict([
                    ('image',
                     [self.image_channels] + list(reversed(self.image_size))),
                    ('heatmaps',
                     [self.num_landmarks] + list(reversed(self.heatmap_size)))
                ])
        else:
            if self.generate_landmarks:
                data_generator_entries = OrderedDict([
                    ('image',
                     list(reversed(self.image_size)) + [self.image_channels]),
                    ('landmarks', [self.num_landmarks, 4])
                ])
            else:
                data_generator_entries = OrderedDict([
                    ('image',
                     list(reversed(self.image_size)) + [self.image_channels]),
                    ('heatmaps',
                     list(reversed(self.heatmap_size)) + [self.num_landmarks])
                ])

        sigmas = tf.get_variable('sigmas', [self.num_landmarks],
                                 initializer=tf.constant_initializer(
                                     self.heatmap_sigma))
        mean_sigmas = tf.reduce_mean(sigmas)
        self.train_queue = DataGenerator(self.dataset_train,
                                         self.coord,
                                         data_generator_entries,
                                         batch_size=self.batch_size)
        #self.train_queue = DataGeneratorDataset(self.dataset_train, data_generator_entries, batch_size=self.batch_size)
        placeholders = self.train_queue.dequeue()
        image = placeholders[0]

        if self.generate_landmarks:
            target_landmarks = placeholders[1]
            target_heatmaps = generate_heatmap_target(
                list(reversed(self.heatmap_size)),
                target_landmarks,
                sigmas,
                scale=self.sigma_scale,
                normalize=True,
                data_format=self.data_format)
            loss_sigmas = self.sigma_regularization * tf.nn.l2_loss(
                sigmas * target_landmarks[0, :, 0])
        else:
            target_heatmaps = placeholders[1]
            loss_sigmas = self.sigma_regularization * tf.nn.l2_loss(sigmas)
        heatmaps, _, _ = net(image,
                             num_heatmaps=self.num_landmarks,
                             is_training=True,
                             data_format=self.data_format)
        self.loss_net = self.loss_function(heatmaps, target_heatmaps)

        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            if self.reg_constant > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                self.loss_reg = self.reg_constant * tf.add_n(reg_losses)
                self.loss = self.loss_net + self.loss_reg + loss_sigmas
            else:
                self.loss_reg = tf.constant(0, tf.float32)  # keep loss_reg defined for the loss summaries below
                self.loss = self.loss_net + loss_sigmas

        self.train_losses = OrderedDict([('loss', self.loss_net),
                                         ('loss_reg', self.loss_reg),
                                         ('loss_sigmas', loss_sigmas),
                                         ('mean_sigmas', mean_sigmas)])
        # self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss)
        self.optimizer = tf.train.MomentumOptimizer(
            learning_rate=self.learning_rate, momentum=0.99,
            use_nesterov=True).minimize(self.loss)

        # build val graph
        val_placeholders = create_placeholders(data_generator_entries,
                                               shape_prefix=[1])
        self.image_val = val_placeholders['image']
        if self.generate_landmarks:
            self.target_landmarks_val = val_placeholders['landmarks']
            self.target_heatmaps_val = generate_heatmap_target(
                list(reversed(self.heatmap_size)),
                self.target_landmarks_val,
                sigmas,
                scale=self.sigma_scale,
                normalize=True,
                data_format=self.data_format)
            loss_sigmas_val = self.sigma_regularization * tf.nn.l2_loss(
                sigmas * self.target_landmarks_val[0, :, 0])
        else:
            self.target_heatmaps_val = val_placeholders['heatmaps']
            loss_sigmas_val = self.sigma_regularization * tf.nn.l2_loss(sigmas)
        self.heatmaps_val, self.heatmaps_local_val, self.heatmaps_global_val = net(
            self.image_val,
            num_heatmaps=self.num_landmarks,
            is_training=False,
            data_format=self.data_format)

        # losses
        self.loss_val = self.loss_function(self.heatmaps_val,
                                           self.target_heatmaps_val)
        self.val_losses = OrderedDict([('loss', self.loss_val),
                                       ('loss_reg', self.loss_reg),
                                       ('loss_sigmas', loss_sigmas_val),
                                       ('mean_sigmas', mean_sigmas)])

    def test_full_image(self, dataset_entry):
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        heatmap_transform = transformations['image']
        feed_dict = {
            self.image_val: np.expand_dims(generators['image'], axis=0)
        }
        if self.generate_landmarks:
            feed_dict[self.target_landmarks_val] = np.expand_dims(
                generators['landmarks'], axis=0)
        else:
            feed_dict[self.target_heatmaps_val] = np.expand_dims(
                generators['heatmaps'], axis=0)

        # run loss and update loss accumulators
        run_tuple = self.sess.run(
            (self.heatmaps_val, self.target_heatmaps_val, self.loss_val) +
            self.val_loss_aggregator.get_update_ops(), feed_dict)
        heatmaps = np.squeeze(run_tuple[0], axis=0)
        image = generators['image']

        return image, heatmaps, heatmap_transform

    def test_cropped_image(self, dataset_entry):
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        heatmap_transform = transformations['image']

        image_size_np = [1] + list(reversed(self.image_size))
        heatmap_size_np = [self.num_landmarks] + list(reversed(
            self.image_size))
        full_image = generators['image']
        landmarks = generators['landmarks']
        image_tiler = ImageTiler(full_image.shape, image_size_np,
                                 self.cropped_inc, True, -1)
        landmark_tiler = LandmarkTiler(full_image.shape, image_size_np,
                                       self.cropped_inc)
        heatmap_tiler = ImageTiler(
            (self.num_landmarks, ) + full_image.shape[1:], heatmap_size_np,
            self.cropped_inc, True, 0)

        for image_tiler, landmark_tiler, heatmap_tiler in zip(
                image_tiler, landmark_tiler, heatmap_tiler):
            current_image = image_tiler.get_current_data(full_image)
            current_landmarks = landmark_tiler.get_current_data(landmarks)
            feed_dict = {
                self.image_val:
                np.expand_dims(current_image, axis=0),
                self.target_landmarks_val:
                np.expand_dims(current_landmarks, axis=0)
            }
            run_tuple = self.sess.run(
                (self.heatmaps_val, self.target_heatmaps_val, self.loss_val) +
                self.val_loss_aggregator.get_update_ops(), feed_dict)
            prediction = np.squeeze(run_tuple[0], axis=0)
            image_tiler.set_current_data(current_image)
            heatmap_tiler.set_current_data(prediction)

        return image_tiler.output_image, heatmap_tiler.output_image, heatmap_transform

    def test(self):
        print('Testing...')
        if self.data_format == 'channels_first':
            np_channel_index = 0
        else:
            np_channel_index = 3
        heatmap_maxima = HeatmapTest(np_channel_index, False)
        landmark_statistics = LandmarkStatistics()
        landmarks = {}
        for i in range(self.dataset_val.num_entries()):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            image_datasource = datasources['image_datasource']
            landmarks_datasource = datasources['landmarks_datasource']

            if not self.cropped_training:
                image, heatmaps, heatmap_transform = self.test_full_image(
                    dataset_entry)
            else:
                image, heatmaps, heatmap_transform = self.test_cropped_image(
                    dataset_entry)

            utils.io.image.write_np(
                ShiftScaleClamp(scale=255, clamp_min=0,
                                clamp_max=255)(heatmaps).astype(np.uint8),
                self.output_file_for_current_iteration(current_id +
                                                       '_heatmaps.mha'))
            utils.io.image.write_np(
                image,
                self.output_file_for_current_iteration(current_id +
                                                       '_image.mha'))

            predicted_landmarks = heatmap_maxima.get_landmarks(
                heatmaps, image_datasource, self.image_spacing,
                heatmap_transform)
            landmarks[current_id] = predicted_landmarks
            landmark_statistics.add_landmarks(current_id, predicted_landmarks,
                                              landmarks_datasource)

            tensorflow_train.utils.tensorflow_util.print_progress_bar(
                i,
                self.dataset_val.num_entries(),
                prefix='Testing ',
                suffix=' complete')

        tensorflow_train.utils.tensorflow_util.print_progress_bar(
            self.dataset_val.num_entries(),
            self.dataset_val.num_entries(),
            prefix='Testing ',
            suffix=' complete')
        print(landmark_statistics.get_pe_overview_string())
        print(landmark_statistics.get_correct_id_string(20.0))
        summary_values = OrderedDict(
            zip(
                self.point_statistics_names,
                list(landmark_statistics.get_pe_statistics()) +
                [landmark_statistics.get_correct_id(20)]))

        # finalize loss values
        self.val_loss_aggregator.finalize(self.current_iter, summary_values)
        utils.io.landmark.save_points_csv(
            landmarks, self.output_file_for_current_iteration('points.csv'))
        overview_string = landmark_statistics.get_overview_string(
            [2, 2.5, 3, 4, 10, 20], 10, 20.0)
        utils.io.text.save_string_txt(
            overview_string,
            self.output_file_for_current_iteration('eval.txt'))
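A hypothetical driver for this class, assuming the surrounding MainLoopBase framework exposes a run() method that starts training and triggers testing/snapshotting according to the *_iter settings (run() is not shown in these excerpts):

# Hypothetical usage sketch; MainLoopBase.run() is assumed, not shown in the source.
if __name__ == '__main__':
    loop = MainLoop(cv=-1, network_id='scn')  # cv == -1 skips the per-fold output suffix
    loop.run()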
Example #8
class MainLoop(MainLoopBase):
    def __init__(self, network_id, cv, landmark_source, sigma_regularization, output_folder_name=''):
        super().__init__()
        self.network_id = network_id
        self.output_folder = os.path.join('output', network_id, landmark_source, str(cv) if cv >= 0 else 'all', output_folder_name, self.output_folder_timestamp())
        self.batch_size = 1
        self.max_iter = 30000
        self.learning_rate = 0.000001
        self.test_iter = 5000
        self.disp_iter = 100
        self.snapshot_iter = self.test_iter
        self.test_initialization = False
        self.current_iter = 0
        self.reg_constant = 0.001
        self.cv = cv
        self.landmark_source = landmark_source
        original_image_extend = [193.5, 240.0]
        image_sizes = {'unet': [512, 512],
                       'scn_mia': [512, 512]}
        heatmap_sizes = {'unet': [512, 512],
                         'scn_mia': [512, 512]}
        sigmas = {'unet': 2.5,
                  'scn_mia': 2.5}
        self.image_size = image_sizes[self.network_id]
        self.heatmap_size = heatmap_sizes[self.network_id]
        self.image_spacing = [float(np.max([e / s for e, s in zip(original_image_extend, self.image_size)]))] * 2
        self.sigma = sigmas[self.network_id]
        self.image_channels = 1
        self.num_landmarks = 19
        self.heatmap_sigma = self.sigma
        self.sigma_regularization = sigma_regularization
        self.sigma_scale = 100.0
        self.data_format = 'channels_first'
        self.save_debug_images = False
        self.base_folder = './'
        dataset_parameters = {'image_size': self.image_size,
                              'heatmap_size': self.heatmap_size,
                              'image_spacing': self.image_spacing,
                              'num_landmarks': self.num_landmarks,
                              'base_folder': self.base_folder,
                              'data_format': self.data_format,
                              'save_debug_images': self.save_debug_images,
                              'cv': self.cv,
                              'landmark_source': self.landmark_source}

        dataset = Dataset(**dataset_parameters)
        self.dataset_train = dataset.dataset_train()
        self.dataset_val = dataset.dataset_val()

        networks = {'unet': network_unet,
                    'scn_mia': network_scn_mia}
        self.network = networks[self.network_id]
        self.landmark_metrics = ['pe_mean', 'pe_std', 'pe_median', 'or2', 'or25', 'or3', 'or4', 'or10']
        self.landmark_metric_prefixes = ['challenge', 'senior', 'junior', 'mean']
        self.additional_summaries_placeholders_val = OrderedDict([(prefix + '_' + name, create_summary_placeholder(prefix + '_' + name)) for name in self.landmark_metrics for prefix in self.landmark_metric_prefixes])

    def loss_function(self, target, prediction):
        return tf.nn.l2_loss(target - prediction) / get_batch_channel_image_size(target, self.data_format)[0]

    def loss_sigmas(self, sigmas, landmarks):
        return self.sigma_regularization * tf.nn.l2_loss(sigmas[None, :] * landmarks[:, :, 0]) / landmarks.get_shape().as_list()[0]

    def initNetworks(self):
        net = tf.make_template('net', self.network)

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([('image', [self.image_channels] + list(reversed(self.image_size))),
                                                  ('landmarks', [self.num_landmarks, 3])])
            data_generator_entries_val = OrderedDict([('image', [self.image_channels] + list(reversed(self.image_size))),
                                                      ('landmarks', [self.num_landmarks, 3])])
        else:
            data_generator_entries = OrderedDict([('image', list(reversed(self.image_size)) + [self.image_channels]),
                                                  ('landmarks', [self.num_landmarks, 3])])
            data_generator_entries_val = OrderedDict([('image', list(reversed(self.image_size)) + [self.image_channels]),
                                                      ('landmarks', [self.num_landmarks, 3])])

        sigmas = tf.get_variable('sigmas', [self.num_landmarks], initializer=tf.constant_initializer(self.heatmap_sigma))
        sigmas_list = [(f's{i}', sigmas[i]) for i in range(self.num_landmarks)]

        # build training graph
        self.train_queue = DataGenerator(self.dataset_train, self.coord, data_generator_entries, batch_size=self.batch_size, n_threads=8)
        placeholders = self.train_queue.dequeue()
        image = placeholders[0]
        target_landmarks = placeholders[1]
        prediction = net(image, num_landmarks=self.num_landmarks, is_training=True, data_format=self.data_format)
        target_heatmaps = generate_heatmap_target(list(reversed(self.heatmap_size)), target_landmarks, sigmas, scale=self.sigma_scale, normalize=True, data_format=self.data_format)
        loss_sigmas = self.loss_sigmas(sigmas, target_landmarks)
        self.loss_reg = get_reg_loss(self.reg_constant)
        self.loss_net = self.loss_function(target_heatmaps, prediction)
        self.loss = self.loss_net + tf.cast(self.loss_reg, tf.float32) + loss_sigmas

        # optimizer
        global_step = tf.Variable(self.current_iter, trainable=False)
        optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate, momentum=0.99, use_nesterov=True)
        unclipped_gradients, variables = zip(*optimizer.compute_gradients(self.loss))
        norm = tf.global_norm(unclipped_gradients)
        gradients, _ = tf.clip_by_global_norm(unclipped_gradients, 10000.0)
        self.optimizer = optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)
        self.train_losses = OrderedDict([('loss', self.loss_net), ('loss_reg', self.loss_reg), ('loss_sigmas', loss_sigmas), ('norm', norm)] + sigmas_list)

        # build val graph
        self.val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(data_generator_entries_val, shape_prefix=[1])
        self.image_val = self.val_placeholders['image']
        self.target_landmarks_val = self.val_placeholders['landmarks']
        self.prediction_val = net(self.image_val, num_landmarks=self.num_landmarks, is_training=False, data_format=self.data_format)
        self.target_heatmaps_val = generate_heatmap_target(list(reversed(self.heatmap_size)), self.target_landmarks_val, sigmas, scale=self.sigma_scale, normalize=True, data_format=self.data_format)

        # losses
        self.loss_val = self.loss_function(self.target_heatmaps_val, self.prediction_val)
        self.val_losses = OrderedDict([('loss', self.loss_val), ('loss_reg', self.loss_reg), ('loss_sigmas', tf.constant(0, tf.float32)), ('norm', tf.constant(0, tf.float32))] + sigmas_list)

    def test_full_image(self, dataset_entry):
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        feed_dict = {self.val_placeholders['image']: np.expand_dims(generators['image'], axis=0),
                     self.val_placeholders['landmarks']: np.expand_dims(generators['landmarks'], axis=0)}

        # run loss and update loss accumulators
        run_tuple = self.sess.run((self.prediction_val, self.target_heatmaps_val, self.loss_val) + self.val_loss_aggregator.get_update_ops(), feed_dict)
        prediction = np.squeeze(run_tuple[0], axis=0)
        target_heatmaps = np.squeeze(run_tuple[1], axis=0)
        image = generators['image']
        transformation = transformations['image']

        return image, prediction, target_heatmaps, transformation

    def finalize_landmark_statistics(self, landmark_statistics, prefix):
        pe_mean, pe_std, pe_median = landmark_statistics.get_pe_statistics()
        or2, or25, or3, or4, or10 = landmark_statistics.get_num_outliers([2.0, 2.5, 3.0, 4.0, 10.0], True)
        print(prefix + '_pe', ['{0:.3f}'.format(s) for s in [pe_mean, pe_std, pe_median]])
        print(prefix + '_outliers', ['{0:.3f}'.format(s) for s in [or2, or25, or3, or4, or10]])
        overview_string = landmark_statistics.get_overview_string([2, 2.5, 3, 4, 10, 20], 10, 20.0)
        utils.io.text.save_string_txt(overview_string, self.output_file_for_current_iteration(prefix + '_eval.txt'))
        additional_summaries = {prefix + '_pe_mean': pe_mean,
                                prefix + '_pe_std': pe_std,
                                prefix + '_pe_median': pe_median,
                                prefix + '_or2': or2,
                                prefix + '_or25': or25,
                                prefix + '_or3': or3,
                                prefix + '_or4': or4,
                                prefix + '_or10': or10}
        return additional_summaries

    def test(self):
        heatmap_test = HeatmapTest(channel_axis=0, invert_transformation=False)
        challenge_landmark_statistics = LandmarkStatistics()
        senior_landmark_statistics = LandmarkStatistics()
        junior_landmark_statistics = LandmarkStatistics()
        mean_landmark_statistics = LandmarkStatistics()

        landmarks = {}
        for i in range(self.dataset_val.num_entries()):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            reference_image = datasources['image_datasource']
            groundtruth_challenge_landmarks = datasources['challenge_landmarks_datasource']
            groundtruth_senior_landmarks = datasources['senior_landmarks_datasource']
            groundtruth_junior_landmarks = datasources['junior_landmarks_datasource']
            groundtruth_mean_landmarks = datasources['mean_landmarks_datasource']
            image, prediction, target_heatmaps, transform = self.test_full_image(dataset_entry)

            utils.io.image.write_np(image, self.output_file_for_current_iteration(current_id + '_image.mha'))
            utils.io.image.write_np(prediction, self.output_file_for_current_iteration(current_id + '_prediction.mha'))
            utils.io.image.write_np(target_heatmaps, self.output_file_for_current_iteration(current_id + '_target_heatmap.mha'))
            predicted_landmarks = heatmap_test.get_landmarks(prediction, reference_image, output_spacing=self.image_spacing, transformation=transform)
            tensorflow_train.utils.tensorflow_util.print_progress_bar(i, self.dataset_val.num_entries())
            landmarks[current_id] = predicted_landmarks
            challenge_landmark_statistics.add_landmarks(current_id, predicted_landmarks, groundtruth_challenge_landmarks)
            senior_landmark_statistics.add_landmarks(current_id, predicted_landmarks, groundtruth_senior_landmarks)
            junior_landmark_statistics.add_landmarks(current_id, predicted_landmarks, groundtruth_junior_landmarks)
            mean_landmark_statistics.add_landmarks(current_id, predicted_landmarks, groundtruth_mean_landmarks)

        tensorflow_train.utils.tensorflow_util.print_progress_bar(self.dataset_val.num_entries(), self.dataset_val.num_entries())
        challenge_summaries = self.finalize_landmark_statistics(challenge_landmark_statistics, 'challenge')
        senior_summaries = self.finalize_landmark_statistics(senior_landmark_statistics, 'senior')
        junior_summaries = self.finalize_landmark_statistics(junior_landmark_statistics, 'junior')
        mean_summaries = self.finalize_landmark_statistics(mean_landmark_statistics, 'mean')
        additional_summaries = OrderedDict(chain(senior_summaries.items(), junior_summaries.items(), challenge_summaries.items(), mean_summaries.items()))

        # finalize loss values
        self.val_loss_aggregator.finalize(self.current_iter, additional_summaries)
        utils.io.landmark.save_points_csv(landmarks, self.output_file_for_current_iteration('prediction.csv'))
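HeatmapTest.get_landmarks converts the predicted heatmaps back into landmark coordinates; conceptually this amounts to locating each channel's maximum and transforming that position back into the input image space. The NumPy sketch below covers only the maxima step (the spacing/transformation handling done by the real class is omitted), and the function name is hypothetical:

# Hypothetical NumPy sketch of the heatmap-maxima step assumed inside HeatmapTest.get_landmarks.
import numpy as np

def heatmap_maxima(heatmaps):
    """heatmaps: (num_landmarks, H, W) or (num_landmarks, D, H, W) array of predicted heatmaps."""
    return np.array([np.unravel_index(np.argmax(h), h.shape) for h in heatmaps])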
Example #9
class MainLoop(MainLoopBase):
    def __init__(self,
                 cv,
                 network,
                 unet,
                 network_parameters,
                 learning_rate,
                 output_folder_name=''):
        """
        Initializer.
        :param cv: The cv fold. 0, 1, 2 for CV; 'train_all' for training on whole dataset.
        :param network: The used network. Usually network_u.
        :param unet: The specific instance of the U-Net. Usually UnetClassicAvgLinear3d.
        :param network_parameters: The network parameters passed to unet.
        :param learning_rate: The initial learning rate.
        :param output_folder_name: The output folder name that is used for distinguishing experiments.
        """
        super().__init__()
        self.batch_size = 1
        self.learning_rates = [
            learning_rate, learning_rate * 0.5, learning_rate * 0.1
        ]
        self.learning_rate_boundaries = [10000, 15000]
        self.max_iter = 20000
        self.test_iter = 5000
        self.disp_iter = 100
        self.snapshot_iter = 5000
        self.test_initialization = False
        self.current_iter = 0
        self.reg_constant = 0.0005
        self.use_background = True
        self.num_labels = 1
        self.heatmap_sigma = 2.0
        self.data_format = 'channels_first'
        self.network = network
        self.unet = unet
        self.network_parameters = network_parameters
        self.padding = 'same'

        self.use_pyro_dataset = False
        self.save_output_images = True
        self.save_output_images_as_uint = True  # set to False, if you want to see the direct network output
        self.save_debug_images = False
        self.has_validation_groundtruth = cv in [0, 1, 2]
        self.local_base_folder = '../verse2019_dataset'
        self.image_size = [64, 64, 128]
        self.image_spacing = [8] * 3
        self.output_folder = os.path.join('./output/spine_localization/',
                                          network.__name__, unet.__name__,
                                          output_folder_name, str(cv),
                                          self.output_folder_timestamp())
        dataset_parameters = {
            'base_folder': self.local_base_folder,
            'image_size': self.image_size,
            'image_spacing': self.image_spacing,
            'cv': cv,
            'input_gaussian_sigma': 3.0,
            'generate_spine_heatmap': True,
            'save_debug_images': self.save_debug_images
        }

        dataset = Dataset(**dataset_parameters)
        if self.use_pyro_dataset:
            server_name = '@localhost:51232'
            uri = 'PYRO:verse_dataset' + server_name
            print('using pyro uri', uri)
            self.dataset_train = PyroClientDataset(uri, **dataset_parameters)
        else:
            self.dataset_train = dataset.dataset_train()
        self.dataset_val = dataset.dataset_val()

        self.point_statistics_names = ['pe_mean', 'pe_stdev', 'pe_median']
        self.additional_summaries_placeholders_val = dict([
            (name, create_summary_placeholder(name))
            for name in self.point_statistics_names
        ])

    def loss_function(self, pred, target):
        """
        L2 loss function calculated with prediction and target.
        :param pred: The predicted image.
        :param target: The target image.
        :return: L2 loss of (pred - target) / batch_size
        """
        batch_size, _, _ = get_batch_channel_image_size(pred, self.data_format)
        return tf.nn.l2_loss(pred - target) / batch_size

    def init_networks(self):
        """
        Initialize networks and placeholders.
        """
        network_image_size = list(reversed(self.image_size))

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1] + network_image_size),
                ('spine_heatmap', [1] + network_image_size)
            ])
        else:
            data_generator_entries = OrderedDict([
                ('image', network_image_size + [1]),
                ('spine_heatmap', network_image_size + [1])
            ])

        data_generator_types = {
            'image': tf.float32,
            'spine_heatmap': tf.float32
        }

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGenerator(
            coord=self.coord,
            dataset=self.dataset_train,
            data_names_and_shapes=data_generator_entries,
            data_types=data_generator_types,
            batch_size=self.batch_size)
        data, target_spine_heatmap = self.train_queue.dequeue()

        prediction = training_net(data,
                                  num_labels=self.num_labels,
                                  is_training=True,
                                  actual_network=self.unet,
                                  padding=self.padding,
                                  **self.network_parameters)
        self.loss_net = self.loss_function(target=target_spine_heatmap,
                                           pred=prediction)
        self.loss_reg = get_reg_loss(self.reg_constant)
        self.loss = self.loss_net + tf.cast(self.loss_reg, tf.float32)

        # solver
        global_step = tf.Variable(self.current_iter, trainable=False)
        learning_rate = tf.train.piecewise_constant(
            global_step, self.learning_rate_boundaries, self.learning_rates)
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(self.loss,
                                                  global_step=global_step)
        self.train_losses = OrderedDict([('loss', self.loss_net),
                                         ('loss_reg', self.loss_reg)])

        # build val graph
        self.data_val, self.target_spine_heatmap_val = create_placeholders_tuple(
            data_generator_entries,
            data_types=data_generator_types,
            shape_prefix=[1])
        self.prediction_val = training_net(self.data_val,
                                           num_labels=self.num_labels,
                                           is_training=False,
                                           actual_network=self.unet,
                                           padding=self.padding,
                                           **self.network_parameters)

        if self.has_validation_groundtruth:
            self.loss_val = self.loss_function(
                target=self.target_spine_heatmap_val, pred=self.prediction_val)
            self.val_losses = OrderedDict([('loss', self.loss_val),
                                           ('loss_reg', self.loss_reg)])

    def test_full_image(self, dataset_entry):
        """
        Perform inference on a dataset_entry with the validation network.
        :param dataset_entry: A dataset entry from the dataset.
        :return: input image (np.array), network prediction (np.array), transformation (sitk.Transform)
        """
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        if self.has_validation_groundtruth:
            feed_dict = {
                self.data_val:
                np.expand_dims(generators['image'], axis=0),
                self.target_spine_heatmap_val:
                np.expand_dims(generators['spine_heatmap'], axis=0)
            }
            # run loss and update loss accumulators
            run_tuple = self.sess.run(
                (self.prediction_val, self.loss_val) +
                self.val_loss_aggregator.get_update_ops(),
                feed_dict=feed_dict)
        else:
            feed_dict = {
                self.data_val: np.expand_dims(generators['image'], axis=0)
            }
            # run loss and update loss accumulators
            run_tuple = self.sess.run((self.prediction_val, ),
                                      feed_dict=feed_dict)

        prediction = np.squeeze(run_tuple[0], axis=0)
        transformation = transformations['image']
        image = generators['image']

        return image, prediction, transformation
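        # Hedged usage sketch (mirrors how test() below calls this method; no new
        # functionality):
        #     dataset_entry = self.dataset_val.get_next()
        #     image, prediction, transformation = self.test_full_image(dataset_entry)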

    def test(self):
        """
        The test function. Performs inference on the validation images and calculates the loss.
        """
        print('Testing...')

        channel_axis = 0
        if self.data_format == 'channels_last':
            channel_axis = 3

        landmark_statistics = LandmarkStatistics()
        landmarks = {}
        num_entries = self.dataset_val.num_entries()
        for i in range(num_entries):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            if self.has_validation_groundtruth:
                groundtruth_landmarks = datasources['landmarks']
                groundtruth_landmark = [
                    get_mean_landmark(groundtruth_landmarks)
                ]
            input_image = datasources['image']

            image, prediction, transformation = self.test_full_image(
                dataset_entry)
            predictions_sitk = utils.sitk_image.transform_np_output_to_sitk_input(
                output_image=prediction,
                output_spacing=self.image_spacing,
                channel_axis=channel_axis,
                input_image_sitk=input_image,
                transform=transformation,
                interpolator='linear',
                output_pixel_type=sitk.sitkFloat32)
            if self.save_output_images:
                if self.save_output_images_as_uint:
                    image_normalization = 'min_max'
                    heatmap_normalization = (0, 1)
                    output_image_type = np.uint8
                else:
                    image_normalization = None
                    heatmap_normalization = None
                    output_image_type = np.float32
                origin = transformation.TransformPoint(np.zeros(3, np.float64))
                utils.io.image.write_multichannel_np(
                    image,
                    self.output_file_for_current_iteration(current_id +
                                                           '_input.mha'),
                    output_normalization_mode=image_normalization,
                    data_format=self.data_format,
                    image_type=output_image_type,
                    spacing=self.image_spacing,
                    origin=origin)
                utils.io.image.write_multichannel_np(
                    prediction,
                    self.output_file_for_current_iteration(current_id +
                                                           '_prediction.mha'),
                    output_normalization_mode=heatmap_normalization,
                    data_format=self.data_format,
                    image_type=output_image_type,
                    spacing=self.image_spacing,
                    origin=origin)
                #utils.io.image.write(predictions_sitk[0], self.output_file_for_current_iteration(current_id + '_prediction_original.mha'))

            predictions_com = input_image.TransformContinuousIndexToPhysicalPoint(
                list(
                    reversed(
                        utils.np_image.center_of_mass(
                            utils.sitk_np.sitk_to_np_no_copy(
                                predictions_sitk[0])))))
            current_landmark = [Landmark(predictions_com)]
            landmarks[current_id] = current_landmark

            if self.has_validation_groundtruth:
                landmark_statistics.add_landmarks(current_id, current_landmark,
                                                  groundtruth_landmark)

            print_progress_bar(i,
                               num_entries,
                               prefix='Testing ',
                               suffix=' complete')

        utils.io.landmark.save_points_csv(
            landmarks, self.output_file_for_current_iteration('points.csv'))

        # finalize loss values
        if self.has_validation_groundtruth:
            print(landmark_statistics.get_pe_overview_string())
            summary_values = OrderedDict(
                zip(self.point_statistics_names,
                    list(landmark_statistics.get_pe_statistics())))

            # finalize loss values
            self.val_loss_aggregator.finalize(self.current_iter,
                                              summary_values)
            overview_string = landmark_statistics.get_overview_string(
                [2, 2.5, 3, 4, 10, 20], 10)
            utils.io.text.save_string_txt(
                overview_string,
                self.output_file_for_current_iteration('eval.txt'))

    def init_networks(self):
        network_image_size = self.image_size

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1] + network_image_size),
                ('labels', [self.num_labels] + network_image_size)
            ])
        else:
            data_generator_entries = OrderedDict([
                ('image', network_image_size + [1]),
                ('labels', network_image_size + [self.num_labels])
            ])

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGenerator(self.dataset_train,
                                         self.coord,
                                         data_generator_entries,
                                         batch_size=self.batch_size)
        data, mask = self.train_queue.dequeue()
        prediction, _, _ = training_net(data,
                                        num_labels=self.num_labels,
                                        is_training=True,
                                        data_format=self.data_format)
        # losses
        self.loss_net = self.loss_function(labels=mask,
                                           logits=prediction,
                                           data_format=self.data_format)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            if self.reg_constant > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                self.loss_reg = self.reg_constant * tf.add_n(reg_losses)
                self.loss = self.loss_net + self.loss_reg
            else:
                self.loss_reg = tf.constant(0, tf.float32)
                self.loss = self.loss_net

        self.train_losses = OrderedDict([('loss', self.loss_net),
                                         ('loss_reg', self.loss_reg)])

        # solver
        global_step = tf.Variable(self.current_iter)
        learning_rate = tf.train.piecewise_constant(
            global_step, self.learning_rate_boundaries, self.learning_rates)
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(self.loss,
                                                  global_step=global_step)

        # build val graph
        val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(
            data_generator_entries, shape_prefix=[1])
        self.data_val = val_placeholders['image']
        self.prediction_val, self.local_prediction_val, self.spatial_prediction_val = training_net(
            self.data_val,
            num_labels=self.num_labels,
            is_training=False,
            data_format=self.data_format)
        self.prediction_softmax_val = tf.nn.softmax(
            self.prediction_val,
            axis=1 if self.data_format == 'channels_first' else 4)

        if self.has_validation_groundtruth:
            self.mask_val = val_placeholders['labels']
            # losses
            self.loss_val = self.loss_function(labels=self.mask_val,
                                               logits=self.prediction_val,
                                               data_format=self.data_format)
            self.val_losses = OrderedDict([('loss', self.loss_val),
                                           ('loss_reg', self.loss_reg)])

class MainLoop(MainLoopBase):
    def __init__(self, modality, cv):
        super().__init__()
        self.modality = modality
        self.cv = cv
        self.batch_size = 1
        self.learning_rates = [0.00001, 0.000001]
        self.learning_rate_boundaries = [20000]
        self.max_iter = 40000
        self.test_iter = 5000
        self.disp_iter = 100
        self.snapshot_iter = self.test_iter
        self.test_initialization = False
        self.current_iter = 0
        self.reg_constant = 0.0001
        self.num_labels = 8
        self.data_format = 'channels_first'
        self.channel_axis = 1
        self.save_debug_images = False

        self.has_validation_groundtruth = cv != 0
        self.base_folder = 'mmwhs_dataset'
        self.image_size = [64, 64, 64]
        if modality == 'ct':
            self.image_spacing = [3, 3, 3]
        else:
            self.image_spacing = [4, 4, 4]
        self.input_gaussian_sigma = 1.0
        self.label_gaussian_sigma = 1.0
        self.use_landmarks = True

        self.output_folder = './output/scn_' + modality + '_' + str(
            cv) + '/' + self.output_folder_timestamp()

        dataset_parameters = dict(
            image_size=self.image_size,
            image_spacing=self.image_spacing,
            base_folder=self.base_folder,
            cv=self.cv,
            modality=self.modality,
            input_gaussian_sigma=self.input_gaussian_sigma,
            label_gaussian_sigma=self.label_gaussian_sigma,
            use_landmarks=self.use_landmarks,
            num_labels=self.num_labels,
            data_format=self.data_format,
            save_debug_images=self.save_debug_images)

        self.dataset = Dataset(**dataset_parameters)

        self.dataset_train = self.dataset.dataset_train()
        self.dataset_val = self.dataset.dataset_val()
        self.files_to_copy = ['main.py', 'network.py', 'dataset.py']
        self.dice_names = list(
            map(lambda x: 'dice_{}'.format(x), range(self.num_labels)))
        self.additional_summaries_placeholders_val = dict([
            (name, create_summary_placeholder(name))
            for name in self.dice_names
        ])
        self.loss_function = softmax_cross_entropy_with_logits
        self.network = network_scn

    def init_networks(self):
        network_image_size = self.image_size

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1] + network_image_size),
                ('labels', [self.num_labels] + network_image_size)
            ])
        else:
            data_generator_entries = OrderedDict([
                ('image', network_image_size + [1]),
                ('labels', network_image_size + [self.num_labels])
            ])

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGenerator(self.dataset_train,
                                         self.coord,
                                         data_generator_entries,
                                         batch_size=self.batch_size)
        data, mask = self.train_queue.dequeue()
        prediction, _, _ = training_net(data,
                                        num_labels=self.num_labels,
                                        is_training=True,
                                        data_format=self.data_format)
        # losses
        self.loss_net = self.loss_function(labels=mask,
                                           logits=prediction,
                                           data_format=self.data_format)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            if self.reg_constant > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                self.loss_reg = self.reg_constant * tf.add_n(reg_losses)
                self.loss = self.loss_net + self.loss_reg
            else:
                self.loss_reg = tf.constant(0, tf.float32)
                self.loss = self.loss_net

        self.train_losses = OrderedDict([('loss', self.loss_net),
                                         ('loss_reg', self.loss_reg)])

        # solver
        global_step = tf.Variable(self.current_iter)
        learning_rate = tf.train.piecewise_constant(
            global_step, self.learning_rate_boundaries, self.learning_rates)
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(self.loss,
                                                  global_step=global_step)

        # build val graph
        val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(
            data_generator_entries, shape_prefix=[1])
        self.data_val = val_placeholders['image']
        self.prediction_val, self.local_prediction_val, self.spatial_prediction_val = training_net(
            self.data_val,
            num_labels=self.num_labels,
            is_training=False,
            data_format=self.data_format)
        self.prediction_softmax_val = tf.nn.softmax(
            self.prediction_val,
            axis=1 if self.data_format == 'channels_first' else 4)

        if self.has_validation_groundtruth:
            self.mask_val = val_placeholders['labels']
            # losses
            self.loss_val = self.loss_function(labels=self.mask_val,
                                               logits=self.prediction_val,
                                               data_format=self.data_format)
            self.val_losses = OrderedDict([('loss', self.loss_val),
                                           ('loss_reg', self.loss_reg)])

    def test(self):
        print('Testing...')
        channel_axis = 0
        if self.data_format == 'channels_last':
            channel_axis = 3
        labels = list(range(self.num_labels))
        segmentation_test = SegmentationTest(labels,
                                             channel_axis=channel_axis,
                                             interpolator='cubic',
                                             largest_connected_component=False,
                                             all_labels_are_connected=False)
        segmentation_statistics = SegmentationStatistics(
            labels,
            self.output_folder_for_current_iteration(),
            metrics={'dice': DiceMetric()})
        num_entries = self.dataset_val.num_entries()
        for i in range(num_entries):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            generators = dataset_entry['generators']
            transformations = dataset_entry['transformations']
            if self.has_validation_groundtruth:
                feed_dict = {
                    self.data_val: np.expand_dims(generators['image'], axis=0),
                    self.mask_val: np.expand_dims(generators['labels'], axis=0)
                }
                # run loss and update loss accumulators
                run_tuple = self.sess.run(
                    (self.prediction_softmax_val, self.local_prediction_val,
                     self.spatial_prediction_val, self.loss_val) +
                    self.val_loss_aggregator.get_update_ops(),
                    feed_dict=feed_dict)
            else:
                feed_dict = {
                    self.data_val: np.expand_dims(generators['image'], axis=0)
                }
                # run loss and update loss accumulators
                run_tuple = self.sess.run((self.prediction_softmax_val, ),
                                          feed_dict=feed_dict)

            prediction = np.squeeze(run_tuple[0], axis=0)
            #local_prediction = np.squeeze(run_tuple[1], axis=0)
            #spatial_prediction = np.squeeze(run_tuple[2], axis=0)
            input = datasources['image']
            transformation = transformations['image']
            prediction_labels, prediction_sitk = segmentation_test.get_label_image(
                prediction,
                input,
                self.image_spacing,
                transformation,
                return_transformed_sitk=True)
            utils.io.image.write(
                prediction_labels,
                self.output_file_for_current_iteration(current_id + '.mha'))
            origin = transformation.TransformPoint(np.zeros(3, np.float64))
            utils.io.image.write_multichannel_np(
                prediction,
                self.output_file_for_current_iteration(current_id +
                                                       '_prediction.mha'),
                output_normalization_mode=(0, 1),
                data_format=self.data_format,
                image_type=np.uint8,
                spacing=self.image_spacing,
                origin=origin)
            #utils.io.image.write_multichannel_np(local_prediction, self.output_file_for_current_iteration(current_id + '_local_prediction.mha'), output_normalization_mode=(0, 1), data_format=self.data_format, image_type=np.uint8, spacing=self.image_spacing, origin=origin)
            #utils.io.image.write_multichannel_np(spatial_prediction, self.output_file_for_current_iteration(current_id + '_spatial_prediction.mha'), output_normalization_mode=(0, 1), data_format=self.data_format, image_type=np.uint8, spacing=self.image_spacing, origin=origin)
            if self.has_validation_groundtruth:
                groundtruth = datasources['labels']
                segmentation_statistics.add_labels(current_id,
                                                   prediction_labels,
                                                   groundtruth)
            tensorflow_train.utils.tensorflow_util.print_progress_bar(
                i, num_entries, prefix='Testing ', suffix=' complete')

        # finalize loss values
        if self.has_validation_groundtruth:
            segmentation_statistics.finalize()
            dice_list = segmentation_statistics.get_metric_mean_list('dice')
            dice_dict = OrderedDict(list(zip(self.dice_names, dice_list)))
            self.val_loss_aggregator.finalize(self.current_iter,
                                              summary_values=dice_dict)

    def initNetworks(self):
        network_image_size = list(reversed(self.image_size))
        global_step = tf.Variable(self.current_iter)

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('data', [1] + network_image_size),
                ('mask', [self.num_labels] + network_image_size)
            ])

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGenerator(self.dataset_train,
                                         self.coord,
                                         data_generator_entries,
                                         batch_size=self.batch_size)
        data, mask = self.train_queue.dequeue()
        with tf.variable_scope('training', reuse=False):
            if self.network.__name__ == 'network_ud' or self.network.__name__ == 'SegCaps_multilabels':
                prediction = training_net(data,
                                          num_labels=self.num_labels,
                                          is_training=True,
                                          data_format=self.data_format)
            else:
                prediction = training_net(data,
                                          routing_type=self.routing_type,
                                          num_labels=self.num_labels,
                                          is_training=True,
                                          data_format=self.data_format)

        # print parameter count
        logging.info('------------')
        var_num = np.sum([
            np.prod([xi.value for xi in x.get_shape()])
            for x in tf.global_variables()
        ])
        logging.info('Number of network parameters: ' + str(var_num))

        # losses
        if 'spread_loss' in self.loss_function.__name__:
            self.loss_net = self.loss_function(labels=mask,
                                               logits=prediction,
                                               global_step=global_step,
                                               data_format=self.data_format)
        else:
            self.loss_net = self.loss_function(labels=mask,
                                               logits=prediction,
                                               data_format=self.data_format)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.loss = self.loss_net

        self.train_losses = OrderedDict([('loss', self.loss_net)])

        # solver
        self.optimizer = tf.train.AdadeltaOptimizer(learning_rate=1).minimize(
            self.loss,
            global_step=global_step,
            var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                       scope='training'))

        # build val graph
        val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(
            data_generator_entries, shape_prefix=[1])
        self.data_val = val_placeholders['data']

        with tf.variable_scope('testing', reuse=True):

            if self.network.__name__ == 'network_ud' or self.network.__name__ == 'SegCaps_multilabels':
                self.prediction_val = training_net(
                    self.data_val,
                    num_labels=self.num_labels,
                    is_training=False,
                    data_format=self.data_format)
            else:
                self.prediction_val = training_net(
                    self.data_val,
                    routing_type=self.routing_type,
                    num_labels=self.num_labels,
                    is_training=False,
                    data_format=self.data_format)
            self.mask_val = val_placeholders['mask']

            # losses
            self.loss_val = self.loss_function(labels=self.mask_val,
                                               logits=self.prediction_val,
                                               data_format=self.data_format)
            self.val_losses = OrderedDict([('loss', self.loss_val)])

class MainLoop(MainLoopBase):
    def __init__(self, param):
        super().__init__()
        #polyaxon
        data_dir = os.path.join(
            list(get_data_paths().values())[0], "lung/JSRT/preprocessed/")
        logging.info('DATA DIR = ' + data_dir)
        output_path = get_outputs_path()

        self.loss_function = param[0]
        self.network = param[1]
        self.routing_type = param[2]

        self.batch_size = 1
        self.learning_rates = [1, 1]
        self.max_iter = 300000
        self.test_iter = 10000
        self.disp_iter = 100
        self.snapshot_iter = self.test_iter
        self.test_initialization = False
        self.current_iter = 0
        self.num_labels = 6
        self.data_format = 'channels_first'  # WARNING: Capsules might not work with channels_last!
        self.channel_axis = 1
        self.save_debug_images = False
        self.base_folder = data_dir  ##input folder
        self.image_size = [128, 128]
        self.image_spacing = [1, 1]
        self.output_folder = output_path + self.network.__name__ + '_' + self.output_folder_timestamp(
        )  ##output save
        self.dataset = Dataset(image_size=self.image_size,
                               image_spacing=self.image_spacing,
                               num_labels=self.num_labels,
                               base_folder=self.base_folder,
                               data_format=self.data_format,
                               save_debug_images=self.save_debug_images)

        self.dataset_train = self.dataset.dataset_train()
        self.dataset_train.get_next()
        self.dataset_val = self.dataset.dataset_val()
        self.dice_names = list(
            map(lambda x: 'dice_{}'.format(x), range(self.num_labels)))
        self.additional_summaries_placeholders_val = dict([
            (name, create_summary_placeholder(name))
            for name in self.dice_names
        ])

        if self.network.__name__ == 'network_ud':
            self.net_file = './Lung_Segmentation/LungSeg/cnn_network.py'
        elif self.network.__name__ == 'SegCaps_multilabels':
            self.net_file = './Lung_Segmentation/LungSeg/SegCaps/SegCaps.py'
        else:
            self.net_file = './Lung_Segmentation/LungSeg/capsule_network.py'
        self.files_to_copy = ['main_train_and_test.py', self.net_file]

    def initNetworks(self):
        network_image_size = list(reversed(self.image_size))
        global_step = tf.Variable(self.current_iter)

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('data', [1] + network_image_size),
                ('mask', [self.num_labels] + network_image_size)
            ])

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGenerator(self.dataset_train,
                                         self.coord,
                                         data_generator_entries,
                                         batch_size=self.batch_size)
        data, mask = self.train_queue.dequeue()
        with tf.variable_scope('training', reuse=False):
            if self.network.__name__ == 'network_ud' or self.network.__name__ == 'SegCaps_multilabels':
                prediction = training_net(data,
                                          num_labels=self.num_labels,
                                          is_training=True,
                                          data_format=self.data_format)
            else:
                prediction = training_net(data,
                                          routing_type=self.routing_type,
                                          num_labels=self.num_labels,
                                          is_training=True,
                                          data_format=self.data_format)

        # print parameter count
        logging.info('------------')
        var_num = np.sum([
            np.prod([xi.value for xi in x.get_shape()])
            for x in tf.global_variables()
        ])
        logging.info('Number of network parameters: ' + str(var_num))

        # losses
        if 'spread_loss' in self.loss_function.__name__:
            self.loss_net = self.loss_function(labels=mask,
                                               logits=prediction,
                                               global_step=global_step,
                                               data_format=self.data_format)
        else:
            self.loss_net = self.loss_function(labels=mask,
                                               logits=prediction,
                                               data_format=self.data_format)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.loss = self.loss_net

        self.train_losses = OrderedDict([('loss', self.loss_net)])

        # solver
        self.optimizer = tf.train.AdadeltaOptimizer(learning_rate=1).minimize(
            self.loss,
            global_step=global_step,
            var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                       scope='training'))

        # build val graph
        val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(
            data_generator_entries, shape_prefix=[1])
        self.data_val = val_placeholders['data']

        with tf.variable_scope('testing', reuse=True):

            if self.network.__name__ == 'network_ud' or self.network.__name__ == 'SegCaps_multilabels':
                self.prediction_val = training_net(
                    self.data_val,
                    num_labels=self.num_labels,
                    is_training=False,
                    data_format=self.data_format)
            else:
                self.prediction_val = training_net(
                    self.data_val,
                    routing_type=self.routing_type,
                    num_labels=self.num_labels,
                    is_training=False,
                    data_format=self.data_format)
            self.mask_val = val_placeholders['mask']

            # losses
            self.loss_val = self.loss_function(labels=self.mask_val,
                                               logits=self.prediction_val,
                                               data_format=self.data_format)
            self.val_losses = OrderedDict([('loss', self.loss_val)])

    def test(self):
        logging.info('Testing...')
        channel_axis = 0
        if self.data_format == 'channels_last':
            channel_axis = 3
        labels = list(range(self.num_labels))
        segmentation_test = SegmentationTest(labels,
                                             channel_axis=channel_axis,
                                             interpolator='cubic',
                                             largest_connected_component=False,
                                             all_labels_are_connected=False)
        segmentation_statistics = SegmentationStatistics(
            labels,
            self.output_folder_for_current_iteration(),
            metrics={'dice': DiceMetric()})
        num_entries = self.dataset_val.num_entries()
        for i in range(num_entries):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            generators = dataset_entry['generators']
            transformations = dataset_entry['transformations']

            feed_dict = {
                self.data_val: np.expand_dims(generators['data'], axis=0),
                self.mask_val: np.expand_dims(generators['mask'], axis=0)
            }
            # run loss and update loss accumulators
            run_tuple = self.sess.run(
                (self.prediction_val, self.loss_val) +
                self.val_loss_aggregator.get_update_ops(),
                feed_dict=feed_dict)

            prediction = np.squeeze(run_tuple[0], axis=0)
            input = datasources['image']
            transformation = transformations['data']
            prediction_labels, prediction_sitk = segmentation_test.get_label_image(
                prediction,
                input,
                self.image_spacing,
                transformation,
                return_transformed_sitk=True)
            utils.io.image.write_np_colormask(
                prediction_labels,
                self.output_file_for_current_iteration(current_id + '.png'))
            utils.io.image.write_np(
                prediction,
                self.output_file_for_current_iteration(current_id +
                                                       '_prediction.mha'))

            groundtruth = datasources['mask']
            segmentation_statistics.add_labels(current_id, prediction_labels,
                                               groundtruth)
            tensorflow_train.utils.tensorflow_util.print_progress_bar(
                i, num_entries, prefix='Testing ', suffix=' complete')

        # finalize loss values
        segmentation_statistics.finalize()
        dice_list = segmentation_statistics.get_metric_mean_list('dice')
        dice_dict = OrderedDict(list(zip(self.dice_names, dice_list)))
        self.val_loss_aggregator.finalize(self.current_iter,
                                          summary_values=dice_dict)
Ejemplo n.º 14
0
class MainLoop(MainLoopBase):
    def __init__(self, cv, modality):
        super().__init__()
        self.cv = cv
        self.output_folder = './mmwhs_localization/{}_{}'.format(
            modality, cv) + '/' + self.output_folder_timestamp()
        self.batch_size = 1
        self.learning_rate = 0.00001
        self.learning_rates = [self.learning_rate, self.learning_rate * 0.1]
        self.learning_rate_boundaries = [10000]
        self.max_iter = 20000
        self.test_iter = 5000
        self.disp_iter = 100
        self.snapshot_iter = self.test_iter
        self.test_initialization = False
        self.current_iter = 0
        self.reg_constant = 0.00005
        self.invert_transformation = False
        self.image_size = [32] * 3
        if modality == 'ct':
            self.image_spacing = [10] * 3
        else:
            self.image_spacing = [12] * 3
        self.sigma = [1.5] * 3
        self.image_channels = 1
        self.num_landmarks = 1
        self.data_format = 'channels_first'
        self.save_debug_images = False
        self.local_base_folder = '../../semantic_segmentation/mmwhs/mmwhs_dataset'
        dataset_parameters = {
            'base_folder': self.local_base_folder,
            'image_size': self.image_size,
            'image_spacing': self.image_spacing,
            'cv': cv,
            'input_gaussian_sigma': 4.0,
            'modality': modality,
            'save_debug_images': self.save_debug_images
        }

        dataset = Dataset(**dataset_parameters)
        self.dataset_train = dataset.dataset_train()
        self.dataset_val = dataset.dataset_val()
        self.network = network_unet
        self.loss_function = lambda x, y: tf.nn.l2_loss(
            x - y) / get_batch_channel_image_size(x, self.data_format)[0]

    def initNetworks(self):
        net = tf.make_template('net', self.network)

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image',
                 [self.image_channels] + list(reversed(self.image_size))),
                ('landmarks',
                 [self.num_landmarks] + list(reversed(self.image_size)))
            ])
            data_generator_entries_val = OrderedDict([
                ('image',
                 [self.image_channels] + list(reversed(self.image_size))),
                ('landmarks',
                 [self.num_landmarks] + list(reversed(self.image_size)))
            ])
        else:
            raise NotImplementedError

        self.train_queue = DataGenerator(self.dataset_train,
                                         self.coord,
                                         data_generator_entries,
                                         batch_size=self.batch_size)
        image, landmarks = self.train_queue.dequeue()
        prediction = net(image,
                         num_landmarks=self.num_landmarks,
                         is_training=True,
                         data_format=self.data_format)
        self.loss_net = self.loss_function(landmarks, prediction)

        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            if self.reg_constant > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                self.loss_reg = self.reg_constant * tf.add_n(reg_losses)
                self.loss = self.loss_net + self.loss_reg
            else:
                self.loss_reg = 0
                self.loss = self.loss_net

        self.train_losses = OrderedDict([('loss', self.loss_net),
                                         ('loss_reg', self.loss_reg)])
        global_step = tf.Variable(self.current_iter)
        learning_rate = tf.train.piecewise_constant(
            global_step, self.learning_rate_boundaries, self.learning_rates)
        #self.optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.99, use_nesterov=True).minimize(self.loss, global_step=global_step)
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(self.loss,
                                                  global_step=global_step)

        # build val graph
        self.val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(
            data_generator_entries_val, shape_prefix=[1])
        self.image_val = self.val_placeholders['image']
        self.landmarks_val = self.val_placeholders['landmarks']
        self.prediction_val = net(self.image_val,
                                  num_landmarks=self.num_landmarks,
                                  is_training=False,
                                  data_format=self.data_format)

        # losses
        self.loss_val = self.loss_function(self.landmarks_val,
                                           self.prediction_val)
        self.val_losses = OrderedDict([('loss', self.loss_val),
                                       ('loss_reg', self.loss_reg)])

    def test_full_image(self, dataset_entry):
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        feed_dict = {
            self.val_placeholders['image']:
            np.expand_dims(generators['image'], axis=0),
            self.val_placeholders['landmarks']:
            np.expand_dims(generators['landmarks'], axis=0)
        }

        # run loss and update loss accumulators
        run_tuple = self.sess.run((self.prediction_val, self.loss_val) +
                                  self.val_loss_aggregator.get_update_ops(),
                                  feed_dict)
        prediction = np.squeeze(run_tuple[0], axis=0)
        image = generators['image']
        transformation = transformations['image']

        return image, prediction, transformation

    def test(self):
        heatmap_test = HeatmapTest(channel_axis=0, invert_transformation=False)
        landmark_statistics = LandmarkStatistics()

        landmarks = {}
        for i in range(self.dataset_val.num_entries()):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            reference_image = datasources['image']
            groundtruth_landmarks = datasources['landmarks']
            image, prediction, transform = self.test_full_image(dataset_entry)

            utils.io.image.write_np(
                (prediction * 128).astype(np.int8),
                self.output_file_for_current_iteration(current_id +
                                                       '_heatmap.mha'))
            predicted_landmarks = heatmap_test.get_landmarks(
                prediction,
                reference_image,
                transformation=transform,
                output_spacing=self.image_spacing)
            tensorflow_train.utils.tensorflow_util.print_progress_bar(
                i, self.dataset_val.num_entries())
            landmarks[current_id] = predicted_landmarks
            landmark_statistics.add_landmarks(current_id, predicted_landmarks,
                                              groundtruth_landmarks)

        tensorflow_train.utils.tensorflow_util.print_progress_bar(
            self.dataset_val.num_entries(), self.dataset_val.num_entries())
        overview_string = landmark_statistics.get_overview_string(
            [2.0, 4.0, 10.0])
        print(overview_string)

        # finalize loss values
        self.val_loss_aggregator.finalize(self.current_iter)
        utils.io.landmark.save_points_csv(
            landmarks,
            self.output_file_for_current_iteration('prediction.csv'))
        utils.io.text.save_string_txt(
            overview_string,
            self.output_file_for_current_iteration('summary.txt'))
Ejemplo n.º 15
0
class MainLoop(MainLoopBase):
    def __init__(self,
                 cv,
                 network,
                 unet,
                 network_parameters,
                 learning_rate,
                 output_folder_name=''):
        """
        Initializer.
        :param cv: The cv fold. 0, 1, 2 for CV; 'train_all' for training on whole dataset.
        :param network: The used network. Usually network_u.
        :param unet: The specific instance of the U-Net. Usually UnetClassicAvgLinear3d.
        :param network_parameters: The network parameters passed to unet.
        :param learning_rate: The initial learning rate.
        :param output_folder_name: The output folder name that is used for distinguishing experiments.
        """
        super().__init__()
        self.batch_size = 1
        self.learning_rates = [
            learning_rate, learning_rate * 0.5, learning_rate * 0.1
        ]
        self.learning_rate_boundaries = [50000, 75000]
        self.max_iter = 100000
        self.test_iter = 10000
        self.disp_iter = 100
        self.snapshot_iter = 5000
        self.test_initialization = False
        self.current_iter = 0
        self.reg_constant = 0.0005
        self.use_background = True
        self.num_landmarks = 25
        self.heatmap_sigma = 4.0
        self.learnable_sigma = True
        self.data_format = 'channels_first'
        self.network = network
        self.unet = unet
        self.network_parameters = network_parameters
        self.padding = 'same'
        self.clip_gradient_global_norm = 100000.0

        self.use_pyro_dataset = False
        self.use_spine_postprocessing = True
        self.save_output_images = True
        self.save_output_images_as_uint = True  # set to False if you want to see the direct network output
        self.save_debug_images = False
        self.has_validation_groundtruth = cv in [0, 1, 2]
        self.local_base_folder = '../verse2019_dataset'
        self.image_size = [96, 96, 128]
        self.image_spacing = [2] * 3
        self.cropped_inc = [0, 96, 0, 0]
        self.heatmap_size = self.image_size
        self.sigma_regularization = 100
        self.sigma_scale = 1000.0
        self.cropped_training = True
        self.output_folder = os.path.join('./output/vertebrae_localization/',
                                          network.__name__, unet.__name__,
                                          output_folder_name, str(cv),
                                          self.output_folder_timestamp())
        dataset_parameters = {
            'base_folder': self.local_base_folder,
            'image_size': self.image_size,
            'image_spacing': self.image_spacing,
            'cv': cv,
            'input_gaussian_sigma': 0.75,
            'generate_landmarks': True,
            'generate_landmark_mask': True,
            'translate_to_center_landmarks': True,
            'translate_by_random_factor': True,
            'save_debug_images': self.save_debug_images
        }

        dataset = Dataset(**dataset_parameters)
        if self.use_pyro_dataset:
            server_name = '@localhost:51232'
            uri = 'PYRO:verse_dataset' + server_name
            print('using pyro uri', uri)
            self.dataset_train = PyroClientDataset(uri, **dataset_parameters)
        else:
            self.dataset_train = dataset.dataset_train()
        self.dataset_val = dataset.dataset_val()

        self.point_statistics_names = [
            'pe_mean', 'pe_stdev', 'pe_median', 'num_correct'
        ]
        self.additional_summaries_placeholders_val = dict([
            (name, create_summary_placeholder(name))
            for name in self.point_statistics_names
        ])
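        # Hedged usage sketch based on the docstring above; network_u and
        # UnetClassicAvgLinear3d are the defaults it mentions, while the concrete
        # argument values and output_folder_name are illustrative assumptions only:
        #     loop = MainLoop(cv=0,
        #                     network=network_u,
        #                     unet=UnetClassicAvgLinear3d,
        #                     network_parameters={},
        #                     learning_rate=0.0001,
        #                     output_folder_name='baseline_experiment')
        #     loop.run()  # assuming MainLoopBase exposes a run() entry point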

    def loss_function(self, pred, target, mask=None):
        """
        L2 loss function calculated with prediction and target.
        :param pred: The predicted image.
        :param target: The target image.
        :param mask: If not none, calculate loss only pixels, where mask == 1
        :return: L2 loss of (pred - target) / batch_size
        """
        batch_size, _, _ = get_batch_channel_image_size(pred, self.data_format)
        if mask is not None:
            return tf.nn.l2_loss((pred - target) * mask) / batch_size
        else:
            return tf.nn.l2_loss(pred - target) / batch_size
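        # Hedged sketch (illustration, not original code): with a binary mask the
        # masked branch is equivalent to zeroing the residual before the same
        # reduction, roughly
        #     0.5 * np.sum(((pred - target) * mask) ** 2) / batch_size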

    def loss_function_sigmas(self, sigmas, valid_landmarks):
        """
        L2 loss function for sigmas. Only calculated for values ver valid_landmarks == 1.
        :param sigmas: Sigma variables.
        :param valid_landmarks: Valid landmarks. Needs to have same shape as sigmas.
        :return: L2 loss of sigmas * valid_landmarks.
        """
        return self.sigma_regularization * tf.nn.l2_loss(
            sigmas * valid_landmarks)
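        # Hedged sketch (illustration, not original code): since invalid landmarks
        # are multiplied by zero, only sigmas of valid landmarks contribute, roughly
        #     self.sigma_regularization * 0.5 * np.sum((sigmas * valid_landmarks) ** 2)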

    def init_networks(self):
        """
        Initialize networks and placeholders.
        """
        network_image_size = list(reversed(self.image_size))

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1] + network_image_size),
                ('landmarks', [self.num_landmarks, 4]),
                ('landmark_mask', [1] + network_image_size)
            ])
        else:
            data_generator_entries = OrderedDict([
                ('image', network_image_size + [1]),
                ('landmarks', [self.num_landmarks, 4]),
                ('landmark_mask', network_image_size + [1])
            ])

        data_generator_types = {'image': tf.float32}

        # create sigmas variable
        sigmas = tf.get_variable('sigmas', [self.num_landmarks],
                                 initializer=tf.constant_initializer(
                                     self.heatmap_sigma))
        if not self.learnable_sigma:
            sigmas = tf.stop_gradient(sigmas)
        mean_sigmas = tf.reduce_mean(sigmas)

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGenerator(
            coord=self.coord,
            dataset=self.dataset_train,
            data_names_and_shapes=data_generator_entries,
            data_types=data_generator_types,
            batch_size=self.batch_size)
        data, target_landmarks, landmark_mask = self.train_queue.dequeue()
        target_heatmaps = generate_heatmap_target(
            list(reversed(self.heatmap_size)),
            target_landmarks,
            sigmas,
            scale=self.sigma_scale,
            normalize=True,
            data_format=self.data_format)
        prediction, local_prediction, spatial_prediction = training_net(
            data,
            num_labels=self.num_landmarks,
            is_training=True,
            actual_network=self.unet,
            padding=self.padding,
            **self.network_parameters)
        # losses
        self.loss_net = self.loss_function(target=target_heatmaps,
                                           pred=prediction,
                                           mask=landmark_mask)
        self.loss_sigmas = self.loss_function_sigmas(sigmas,
                                                     target_landmarks[0, :, 0])
        self.loss_reg = get_reg_loss(self.reg_constant)
        self.loss = self.loss_net + self.loss_reg + self.loss_sigmas

        # solver
        global_step = tf.Variable(self.current_iter, trainable=False)
        learning_rate = tf.train.piecewise_constant(
            global_step, self.learning_rate_boundaries, self.learning_rates)
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=0.99,
                                               use_nesterov=True)
        unclipped_gradients, variables = zip(
            *optimizer.compute_gradients(self.loss))
        norm = tf.global_norm(unclipped_gradients)
        if self.clip_gradient_global_norm > 0:
            gradients, _ = tf.clip_by_global_norm(
                unclipped_gradients, self.clip_gradient_global_norm)
        else:
            gradients = unclipped_gradients
        self.optimizer = optimizer.apply_gradients(zip(gradients, variables),
                                                   global_step=global_step)
        self.train_losses = OrderedDict([('loss', self.loss_net),
                                         ('loss_reg', self.loss_reg),
                                         ('loss_sigmas', self.loss_sigmas),
                                         ('mean_sigmas', mean_sigmas),
                                         ('gradient_norm', norm)])

        # build val graph
        self.data_val, self.target_landmarks_val, self.landmark_mask_val = create_placeholders_tuple(
            data_generator_entries,
            data_types=data_generator_types,
            shape_prefix=[1])
        self.target_heatmaps_val = generate_heatmap_target(
            list(reversed(self.heatmap_size)),
            self.target_landmarks_val,
            sigmas,
            scale=self.sigma_scale,
            normalize=True,
            data_format=self.data_format)
        self.prediction_val, self.local_prediction_val, self.spatial_prediction_val = training_net(
            self.data_val,
            num_labels=self.num_landmarks,
            is_training=False,
            actual_network=self.unet,
            padding=self.padding,
            **self.network_parameters)

        if self.has_validation_groundtruth:
            self.loss_val = self.loss_function(target=self.target_heatmaps_val,
                                               pred=self.prediction_val)
            self.val_losses = OrderedDict([
                ('loss', self.loss_val), ('loss_reg', self.loss_reg),
                ('loss_sigmas', tf.constant(0, tf.float32)),
                ('mean_sigmas', tf.constant(0, tf.float32)),
                ('gradient_norm', tf.constant(0, tf.float32))
            ])

    def test_cropped_image(self, dataset_entry):
        """
        Perform inference on a dataset_entry with the validation network. Performs cropped prediction and merges outputs as maxima.
        :param dataset_entry: A dataset entry from the dataset.
        :return: input image (np.array), target heatmaps (np.array), predicted heatmaps,  transformation (sitk.Transform)
        """
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        transformation = transformations['image']

        image_size_np = [1] + list(reversed(self.image_size))
        labels_size_np = [self.num_landmarks] + list(reversed(self.image_size))
        full_image = generators['image']
        landmarks = generators['landmarks']
        image_tiler = ImageTiler(full_image.shape, image_size_np,
                                 self.cropped_inc, True, -1)
        landmark_tiler = LandmarkTiler(full_image.shape, image_size_np,
                                       self.cropped_inc)
        prediction_tiler = ImageTiler(
            (self.num_landmarks, ) + full_image.shape[1:], labels_size_np,
            self.cropped_inc, True, 0)

        for image_tiler, landmark_tiler, prediction_tiler in zip(
                image_tiler, landmark_tiler, prediction_tiler):
            current_image = image_tiler.get_current_data(full_image)
            current_landmarks = landmark_tiler.get_current_data(landmarks)
            if self.has_validation_groundtruth:
                feed_dict = {
                    self.data_val:
                    np.expand_dims(current_image, axis=0),
                    self.target_landmarks_val:
                    np.expand_dims(current_landmarks, axis=0)
                }
                run_tuple = self.sess.run(
                    (self.prediction_val, self.loss_val) +
                    self.val_loss_aggregator.get_update_ops(),
                    feed_dict=feed_dict)
            else:
                feed_dict = {
                    self.data_val: np.expand_dims(current_image, axis=0)
                }
                run_tuple = self.sess.run((self.prediction_val, ),
                                          feed_dict=feed_dict)
            prediction = np.squeeze(run_tuple[0], axis=0)
            image_tiler.set_current_data(current_image)
            prediction_tiler.set_current_data(prediction)

        return image_tiler.output_image, prediction_tiler.output_image, transformation
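        # Hedged sketch (assumption about the tiling behavior, not the ImageTiler
        # implementation): merging overlapping cropped predictions "as maxima"
        # amounts to an elementwise maximum into a full-size accumulator, e.g.
        #     merged = np.zeros(full_shape, np.float32)
        #     merged[tile_slices] = np.maximum(merged[tile_slices], tile_prediction)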

    def test(self):
        """
        The test function. Performs inference on the validation images and calculates the loss.
        """
        print('Testing...')

        channel_axis = 0
        if self.data_format == 'channels_last':
            channel_axis = 3
        if self.use_spine_postprocessing:
            heatmap_maxima = HeatmapTest(channel_axis,
                                         False,
                                         return_multiple_maxima=True,
                                         min_max_distance=7,
                                         min_max_value=0.25,
                                         multiple_min_max_value_factor=0.1)
            spine_postprocessing = SpinePostprocessing(
                num_landmarks=self.num_landmarks,
                image_spacing=self.image_spacing)
        else:
            heatmap_maxima = HeatmapTest(channel_axis, False)

        landmark_statistics = LandmarkStatistics()
        landmarks = {}
        num_entries = self.dataset_val.num_entries()
        for i in range(num_entries):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            input_image = datasources['image']
            target_landmarks = datasources['landmarks']

            image, prediction, transformation = self.test_cropped_image(
                dataset_entry)

            if self.save_output_images:
                if self.save_output_images_as_uint:
                    image_normalization = 'min_max'
                    heatmap_normalization = (0, 1)
                    output_image_type = np.uint8
                else:
                    image_normalization = None
                    heatmap_normalization = None
                    output_image_type = np.float32
                origin = transformation.TransformPoint(np.zeros(3, np.float64))
                utils.io.image.write_multichannel_np(
                    image,
                    self.output_file_for_current_iteration(current_id +
                                                           '_input.mha'),
                    output_normalization_mode=image_normalization,
                    data_format=self.data_format,
                    image_type=output_image_type,
                    spacing=self.image_spacing,
                    origin=origin)
                utils.io.image.write_multichannel_np(
                    prediction,
                    self.output_file_for_current_iteration(current_id +
                                                           '_prediction.mha'),
                    output_normalization_mode=heatmap_normalization,
                    data_format=self.data_format,
                    image_type=output_image_type,
                    spacing=self.image_spacing,
                    origin=origin)

            if self.use_spine_postprocessing:
                local_maxima_landmarks = heatmap_maxima.get_landmarks(
                    prediction, input_image, self.image_spacing,
                    transformation)
                landmark_sequence = spine_postprocessing.postprocess_landmarks(
                    local_maxima_landmarks, prediction.shape)
                landmarks[current_id] = landmark_sequence
            else:
                maxima_landmarks = heatmap_maxima.get_landmarks(
                    prediction, input_image, self.image_spacing,
                    transformation)
                landmarks[current_id] = maxima_landmarks

            if self.has_validation_groundtruth:
                landmark_statistics.add_landmarks(current_id,
                                                  landmarks[current_id],
                                                  target_landmarks)

            print_progress_bar(i,
                               num_entries,
                               prefix='Testing ',
                               suffix=' complete')

        utils.io.landmark.save_points_csv(
            landmarks, self.output_file_for_current_iteration('points.csv'))

        # finalize loss values
        if self.has_validation_groundtruth:
            print(landmark_statistics.get_pe_overview_string())
            print(landmark_statistics.get_correct_id_string(20.0))
            summary_values = OrderedDict(
                zip(
                    self.point_statistics_names,
                    list(landmark_statistics.get_pe_statistics()) +
                    [landmark_statistics.get_correct_id(20)]))

            # finalize loss values
            self.val_loss_aggregator.finalize(self.current_iter,
                                              summary_values)
            overview_string = landmark_statistics.get_overview_string(
                [2, 2.5, 3, 4, 10, 20], 10, 20.0)
            utils.io.text.save_string_txt(
                overview_string,
                self.output_file_for_current_iteration('eval.txt'))
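
The landmark extraction in the test routine above is delegated to the project's HeatmapTest and SpinePostprocessing helpers. As a rough illustration of the underlying idea only (a hypothetical, simplified sketch, not the actual implementation), the strongest response of each heatmap channel can be located with a plain argmax and converted to physical coordinates via the image spacing:

import numpy as np

def argmax_landmarks(heatmaps, spacing, min_value=0.25):
    """Return one landmark per heatmap channel, or None where the maximum is too weak.

    heatmaps: np.ndarray of shape (num_landmarks, D, H, W), channels_first.
    spacing:  voxel spacing per image axis, e.g. [1.0, 1.0, 1.0].
    """
    spacing = np.asarray(spacing, dtype=np.float64)
    landmarks = []
    for channel in heatmaps:
        flat_index = int(np.argmax(channel))
        value = channel.flat[flat_index]
        if value < min_value:
            landmarks.append(None)  # treat weak maxima as missing landmarks
            continue
        voxel = np.array(np.unravel_index(flat_index, channel.shape), dtype=np.float64)
        landmarks.append(voxel * spacing)
    return landmarks
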
Example No. 16
class MainLoop(MainLoopBase):
    def __init__(self,
                 cv,
                 network,
                 unet,
                 network_parameters,
                 learning_rate,
                 output_folder_name=''):
        """
        Initializer.
        :param cv: The cv fold. 0, 1, 2 for CV; 'train_all' for training on whole dataset.
        :param network: The used network. Usually network_u.
        :param unet: The specific instance of the U-Net. Usually UnetClassicAvgLinear3d.
        :param network_parameters: The network parameters passed to unet.
        :param learning_rate: The initial learning rate.
        :param output_folder_name: The output folder name that is used for distinguishing experiments.
        """
        super().__init__()
        self.batch_size = 1
        self.learning_rates = [
            learning_rate, learning_rate * 0.5, learning_rate * 0.1
        ]
        self.learning_rate_boundaries = [20000, 30000]
        self.max_iter = 50000
        self.test_iter = 5000
        self.disp_iter = 100
        self.snapshot_iter = 5000
        self.test_initialization = False
        self.current_iter = 0
        self.reg_constant = 0.000001
        self.num_labels = 1
        self.num_labels_all = 26
        self.data_format = 'channels_first'
        self.channel_axis = 1
        self.network = network
        self.unet = unet
        self.network_parameters = network_parameters
        self.padding = 'same'
        self.clip_gradient_global_norm = 1.0

        self.use_pyro_dataset = False
        self.save_output_images = True
        self.save_output_images_as_uint = True  # set to False if you want to see the direct network output
        self.save_debug_images = False
        self.has_validation_groundtruth = cv in [0, 1, 2]
        self.local_base_folder = '../verse2019_dataset'
        self.image_size = [128, 128, 96]
        self.image_spacing = [1] * 3
        self.output_folder = os.path.join('./output/vertebrae_segmentation/',
                                          network.__name__, unet.__name__,
                                          output_folder_name, str(cv),
                                          self.output_folder_timestamp())
        dataset_parameters = {
            'base_folder': self.local_base_folder,
            'image_size': self.image_size,
            'image_spacing': self.image_spacing,
            'cv': cv,
            'input_gaussian_sigma': 0.75,
            'label_gaussian_sigma': 1.0,
            'heatmap_sigma': 3.0,
            'generate_single_vertebrae_heatmap': True,
            'generate_single_vertebrae': True,
            'save_debug_images': self.save_debug_images
        }

        dataset = Dataset(**dataset_parameters)
        if self.use_pyro_dataset:
            server_name = '@localhost:51232'
            uri = 'PYRO:verse_dataset' + server_name
            print('using pyro uri', uri)
            self.dataset_train = PyroClientDataset(uri, **dataset_parameters)
        else:
            self.dataset_train = dataset.dataset_train()
        self.dataset_val = dataset.dataset_val()

        self.dice_names = ['mean_dice'] + list(
            map(lambda x: 'dice_{}'.format(x), range(self.num_labels_all)))
        self.hausdorff_names = ['mean_h'] + list(
            map(lambda x: 'h_{}'.format(x), range(self.num_labels_all)))
        self.additional_summaries_placeholders_val = dict([
            (name, create_summary_placeholder(name))
            for name in (self.dice_names + self.hausdorff_names)
        ])
        self.loss_function = sigmoid_cross_entropy_with_logits

        self.setup_base_folder = os.path.join(self.local_base_folder, 'setup')
        if cv in [0, 1, 2]:
            self.cv_folder = os.path.join(self.setup_base_folder,
                                          os.path.join('cv', str(cv)))
            self.test_file = os.path.join(self.cv_folder, 'val.txt')
        else:
            self.test_file = os.path.join(self.setup_base_folder,
                                          'train_all.txt')
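        # valid_landmarks.csv maps each image id to the list of vertebra
        # (landmark) ids that are actually present in that image; test() below
        # iterates over exactly these ids for every image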
        self.valid_landmarks_file = os.path.join(self.setup_base_folder,
                                                 'valid_landmarks.csv')
        self.test_id_list = utils.io.text.load_list(self.test_file)
        self.valid_landmarks = utils.io.text.load_dict_csv(
            self.valid_landmarks_file)

    def init_networks(self):
        """
        Initialize networks and placeholders.
        """
        network_image_size = list(reversed(self.image_size))

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image', [1] + network_image_size),
                ('single_label', [self.num_labels] + network_image_size),
                ('single_heatmap', [1] + network_image_size)
            ])
        else:
            data_generator_entries = OrderedDict([
                ('image', network_image_size + [1]),
                ('single_label', network_image_size + [self.num_labels]),
                ('single_heatmap', network_image_size + [1])
            ])

        data_generator_types = {'image': tf.float32, 'labels': tf.uint8}

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGenerator(
            coord=self.coord,
            dataset=self.dataset_train,
            data_names_and_shapes=data_generator_entries,
            data_types=data_generator_types,
            batch_size=self.batch_size)
        data, mask, single_heatmap = self.train_queue.dequeue()
        data_heatmap_concat = tf.concat([data, single_heatmap], axis=1)
        prediction = training_net(data_heatmap_concat,
                                  num_labels=self.num_labels,
                                  is_training=True,
                                  actual_network=self.unet,
                                  padding=self.padding,
                                  **self.network_parameters)
        # losses
        self.loss_net = self.loss_function(labels=mask,
                                           logits=prediction,
                                           data_format=self.data_format)
        self.loss_reg = get_reg_loss(self.reg_constant)
        self.loss = self.loss_net + self.loss_reg

        # solver
        global_step = tf.Variable(self.current_iter, trainable=False)
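        # piecewise-constant schedule: full learning rate until iteration
        # 20000, then 0.5x, then 0.1x (see learning_rate_boundaries and
        # learning_rates in __init__)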
        learning_rate = tf.train.piecewise_constant(
            global_step, self.learning_rate_boundaries, self.learning_rates)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        unclipped_gradients, variables = zip(
            *optimizer.compute_gradients(self.loss))
        norm = tf.global_norm(unclipped_gradients)
        if self.clip_gradient_global_norm > 0:
            gradients, _ = tf.clip_by_global_norm(
                unclipped_gradients, self.clip_gradient_global_norm)
        else:
            gradients = unclipped_gradients
        self.optimizer = optimizer.apply_gradients(zip(gradients, variables),
                                                   global_step=global_step)
        self.train_losses = OrderedDict([('loss', self.loss_net),
                                         ('loss_reg', self.loss_reg),
                                         ('gradient_norm', norm)])

        # build val graph
        self.data_val, self.mask_val, self.single_heatmap_val = create_placeholders_tuple(
            data_generator_entries,
            data_types=data_generator_types,
            shape_prefix=[1])
        self.data_heatmap_concat_val = tf.concat(
            [self.data_val, self.single_heatmap_val], axis=1)
        self.prediction_val = training_net(self.data_heatmap_concat_val,
                                           num_labels=self.num_labels,
                                           is_training=False,
                                           actual_network=self.unet,
                                           padding=self.padding,
                                           **self.network_parameters)
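        # note: despite the '*_softmax_val' name this is a per-voxel sigmoid,
        # which matches the sigmoid cross-entropy loss used for the single
        # binary vertebra label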
        self.prediction_softmax_val = tf.nn.sigmoid(self.prediction_val)

        if self.has_validation_groundtruth:
            self.loss_val = self.loss_function(labels=self.mask_val,
                                               logits=self.prediction_val,
                                               data_format=self.data_format)
            self.val_losses = OrderedDict([('loss', self.loss_val),
                                           ('loss_reg', self.loss_reg),
                                           ('gradient_norm',
                                            tf.constant(0, tf.float32))])

    def test_full_image(self, dataset_entry):
        """
        Perform inference on a dataset_entry with the validation network.
        :param dataset_entry: A dataset entry from the dataset.
        :return: input image (np.array), network prediction (np.array), transformation (sitk.Transform)
        """
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        if self.has_validation_groundtruth:
            feed_dict = {
                self.data_val:
                np.expand_dims(generators['image'], axis=0),
                self.mask_val:
                np.expand_dims(generators['single_label'], axis=0),
                self.single_heatmap_val:
                np.expand_dims(generators['single_heatmap'], axis=0)
            }
            # run loss and update loss accumulators
            run_tuple = self.sess.run(
                (self.prediction_softmax_val, self.loss_val) +
                self.val_loss_aggregator.get_update_ops(),
                feed_dict=feed_dict)
        else:
            feed_dict = {
                self.data_val:
                np.expand_dims(generators['image'], axis=0),
                self.single_heatmap_val:
                np.expand_dims(generators['single_heatmap'], axis=0)
            }
            # run prediction only (no ground truth available)
            run_tuple = self.sess.run((self.prediction_softmax_val, ),
                                      feed_dict=feed_dict)
        prediction = np.squeeze(run_tuple[0], axis=0)
        transformation = transformations['image']
        image = generators['image']

        return image, prediction, transformation

    def test(self):
        """
        The test function. Performs inference on the validation images and calculates segmentation statistics and the loss.
        """
        print('Testing...')

        channel_axis = 0
        if self.data_format == 'channels_last':
            channel_axis = 3
        labels = list(range(self.num_labels_all))
        segmentation_test = SegmentationTest(labels,
                                             channel_axis=channel_axis,
                                             largest_connected_component=False,
                                             all_labels_are_connected=False)
        segmentation_statistics = SegmentationStatistics(
            list(range(self.num_labels_all)),
            self.output_folder_for_current_iteration(),
            metrics=OrderedDict([('dice', DiceMetric()),
                                 ('h', HausdorffDistanceMetric())]))
        filter_largest_cc = True

        # iterate over all images
        for i, image_id in enumerate(self.test_id_list):
            first = True
            prediction_resampled_np = None
            input_image = None
            groundtruth = None
            # iterate over all valid landmarks
            for landmark_id in self.valid_landmarks[image_id]:
                dataset_entry = self.dataset_val.get({
                    'image_id': image_id,
                    'landmark_id': landmark_id
                })
                datasources = dataset_entry['datasources']
                if first:
                    input_image = datasources['image']
                    if self.has_validation_groundtruth:
                        groundtruth = datasources['labels']
                    prediction_resampled_np = np.zeros(
                        [self.num_labels_all] +
                        list(reversed(input_image.GetSize())),
                        dtype=np.float16)
                    prediction_resampled_np[0, ...] = 0.5
                    first = False

                image, prediction, transformation = self.test_full_image(
                    dataset_entry)

                if filter_largest_cc:
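                    # keep only the largest connected component of the
                    # thresholded prediction: voxels of the largest component
                    # are cleared from prediction_thresh_np first, so the
                    # remaining marked voxels (all smaller components) are then
                    # zeroed in the soft prediction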
                    prediction_thresh_np = (prediction > 0.5).astype(np.uint8)
                    largest_connected_component = utils.np_image.largest_connected_component(
                        prediction_thresh_np[0])
                    prediction_thresh_np[largest_connected_component[None, ...] == 1] = 0
                    prediction[prediction_thresh_np == 1] = 0

                if self.save_output_images:
                    if self.save_output_images_as_uint:
                        image_normalization = 'min_max'
                        label_normalization = (0, 1)
                        output_image_type = np.uint8
                    else:
                        image_normalization = None
                        label_normalization = None
                        output_image_type = np.float32
                    origin = transformation.TransformPoint(
                        np.zeros(3, np.float64))
                    utils.io.image.write_multichannel_np(
                        image,
                        self.output_file_for_current_iteration(image_id + '_' +
                                                               landmark_id +
                                                               '_input.mha'),
                        output_normalization_mode=image_normalization,
                        data_format=self.data_format,
                        image_type=output_image_type,
                        spacing=self.image_spacing,
                        origin=origin)
                    utils.io.image.write_multichannel_np(
                        prediction,
                        self.output_file_for_current_iteration(
                            image_id + '_' + landmark_id + '_prediction.mha'),
                        output_normalization_mode=label_normalization,
                        data_format=self.data_format,
                        image_type=output_image_type,
                        spacing=self.image_spacing,
                        origin=origin)

                prediction_resampled_sitk = utils.sitk_image.transform_np_output_to_sitk_input(
                    output_image=prediction,
                    output_spacing=self.image_spacing,
                    channel_axis=channel_axis,
                    input_image_sitk=input_image,
                    transform=transformation,
                    interpolator='linear',
                    output_pixel_type=sitk.sitkFloat32)
                #utils.io.image.write(prediction_resampled_sitk[0],  self.output_file_for_current_iteration(image_id + '_' + landmark_id + '_resampled.mha'))
                resampled_np = utils.sitk_np.sitk_to_np(
                    prediction_resampled_sitk[0])
                if self.data_format == 'channels_first':
                    prediction_resampled_np[int(landmark_id) + 1, ...] = resampled_np
                else:
                    prediction_resampled_np[..., int(landmark_id) + 1] = resampled_np
            prediction_labels = segmentation_test.get_label_image(
                prediction_resampled_np, reference_sitk=input_image)
            # delete to save memory
            del prediction_resampled_np
            utils.io.image.write(
                prediction_labels,
                self.output_file_for_current_iteration(image_id + '.mha'))

            if self.has_validation_groundtruth:
                segmentation_statistics.add_labels(image_id, prediction_labels,
                                                   groundtruth)

            print_progress_bar(i,
                               len(self.test_id_list),
                               prefix='Testing ',
                               suffix=' complete')

        # finalize loss values
        if self.has_validation_groundtruth:
            segmentation_statistics.finalize()
            dice_list = segmentation_statistics.get_metric_mean_list('dice')
            mean_dice = np.nanmean(dice_list)
            dice_list = [mean_dice] + dice_list
            hausdorff_list = segmentation_statistics.get_metric_mean_list('h')
            mean_hausdorff = np.mean(hausdorff_list)
            hausdorff_list = [mean_hausdorff] + hausdorff_list
            summary_values = OrderedDict(
                list(zip(self.dice_names, dice_list)) +
                list(zip(self.hausdorff_names, hausdorff_list)))
            self.val_loss_aggregator.finalize(self.current_iter,
                                              summary_values=summary_values)
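
The per-vertebra binary predictions above are merged into a single multi-label volume by the project's SegmentationTest helper. As a rough, hypothetical sketch of that merge step only (a channel-wise argmax over a background channel plus one channel per vertebra; not the actual implementation):

import numpy as np

def merge_binary_predictions(prediction_channels):
    """prediction_channels: (num_labels_all, D, H, W) with channel 0 = background.

    Each voxel is assigned the label of the channel with the highest value;
    because channel 0 is pre-filled with 0.5 above, a vertebra label is only
    assigned where its sigmoid output exceeds that background prior.
    """
    return np.argmax(prediction_channels, axis=0).astype(np.uint8)
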
Example No. 17
class MainLoop(MainLoopBase):
    def __init__(self, cv, network_id):
        super().__init__()
        self.cv = cv
        self.network_id = network_id
        self.output_folder = '{}_cv{}'.format(
            network_id, cv) + '/' + self.output_folder_timestamp()
        self.batch_size = 8
        self.learning_rate = 0.0000001
        self.max_iter = 20000
        self.test_iter = 5000
        self.disp_iter = 100
        self.snapshot_iter = self.test_iter
        self.test_initialization = True
        self.current_iter = 0
        self.reg_constant = 0.00005
        self.invert_transformation = False
        image_sizes = {
            'scn': [256, 256],
            'unet': [256, 256],
            'downsampling': [256, 256],
            'conv': [128, 128],
            'scn_mmwhs': [256, 256]
        }
        heatmap_sizes = {
            'scn': [256, 256],
            'unet': [256, 256],
            'downsampling': [64, 64],
            'conv': [128, 128],
            'scn_mmwhs': [256, 256]
        }
        sigmas = {
            'scn': 3.0,
            'unet': 3.0,
            'downsampling': 1.5,
            'conv': 3.0,
            'scn_mmwhs': 3.0
        }
        self.image_size = image_sizes[self.network_id]
        self.heatmap_size = heatmap_sizes[self.network_id]
        self.sigma = sigmas[self.network_id]
        self.image_channels = 1
        self.num_landmarks = 37
        self.data_format = 'channels_first'
        self.save_debug_images = False
        self.base_folder = 'hand_xray_dataset/'
        dataset = Dataset(self.image_size, self.heatmap_size,
                          self.num_landmarks, self.sigma, self.base_folder,
                          self.cv, self.data_format, self.save_debug_images)
        self.dataset_train = dataset.dataset_train()
        self.dataset_train.get_next()
        self.dataset_val = dataset.dataset_val()

        networks = {
            'scn': network_scn,
            'unet': network_unet,
            'downsampling': network_downsampling,
            'conv': network_conv,
            'scn_mmwhs': network_scn_mmwhs
        }
        self.network = networks[self.network_id]
        self.loss_function = lambda x, y: tf.nn.l2_loss(
            x - y) / get_batch_channel_image_size(x, self.data_format)[0]
        self.files_to_copy = ['main.py', 'network.py', 'dataset.py']

    def initNetworks(self):
        net = tf.make_template('net', self.network)

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([
                ('image',
                 [self.image_channels] + list(reversed(self.image_size))),
                ('landmarks',
                 [self.num_landmarks] + list(reversed(self.heatmap_size)))
            ])
            data_generator_entries_val = OrderedDict([
                ('image',
                 [self.image_channels] + list(reversed(self.image_size))),
                ('landmarks',
                 [self.num_landmarks] + list(reversed(self.heatmap_size)))
            ])
        else:
            raise NotImplementedError

        self.train_queue = DataGenerator(self.dataset_train,
                                         self.coord,
                                         data_generator_entries,
                                         batch_size=self.batch_size,
                                         n_threads=8)
        placeholders = self.train_queue.dequeue()
        image = placeholders[0]
        landmarks = placeholders[1]
        prediction = net(image,
                         num_landmarks=self.num_landmarks,
                         is_training=True,
                         data_format=self.data_format)
        self.loss_net = self.loss_function(landmarks, prediction)

        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
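            # ensure that collected UPDATE_OPS (e.g. batch-norm moving-average
            # updates) are executed whenever the loss built below is evaluated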
            if self.reg_constant > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                self.loss_reg = self.reg_constant * tf.add_n(reg_losses)
                self.loss = self.loss_net + self.loss_reg
            else:
                self.loss_reg = 0
                self.loss = self.loss_net

        self.train_losses = OrderedDict([('loss', self.loss_net),
                                         ('loss_reg', self.loss_reg)])
        #self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss)
        self.optimizer = tf.train.MomentumOptimizer(
            learning_rate=self.learning_rate, momentum=0.99,
            use_nesterov=True).minimize(self.loss)

        # build val graph
        self.val_placeholders = tensorflow_train.utils.tensorflow_util.create_placeholders(
            data_generator_entries_val, shape_prefix=[1])
        self.image_val = self.val_placeholders['image']
        self.landmarks_val = self.val_placeholders['landmarks']
        self.prediction_val = net(self.image_val,
                                  num_landmarks=self.num_landmarks,
                                  is_training=False,
                                  data_format=self.data_format)

        # losses
        self.loss_val = self.loss_function(self.landmarks_val,
                                           self.prediction_val)
        self.val_losses = OrderedDict([('loss', self.loss_val),
                                       ('loss_reg', self.loss_reg)])

    def test_full_image(self, dataset_entry):
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        feed_dict = {
            self.val_placeholders['image']:
            np.expand_dims(generators['image'], axis=0),
            self.val_placeholders['landmarks']:
            np.expand_dims(generators['landmarks'], axis=0)
        }

        # run loss and update loss accumulators
        run_tuple = self.sess.run((self.prediction_val, self.loss_val) +
                                  self.val_loss_aggregator.get_update_ops(),
                                  feed_dict)
        prediction = np.squeeze(run_tuple[0], axis=0)
        image = generators['image']
        transformation = transformations['image']

        return image, prediction, transformation

    def test(self):
        heatmap_test = HeatmapTest(
            channel_axis=0, invert_transformation=self.invert_transformation)
        landmark_statistics = LandmarkStatistics()

        landmarks = {}
        for i in range(self.dataset_val.num_entries()):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            reference_image = datasources['image_datasource']
            groundtruth_landmarks = datasources['landmarks_datasource']
            image, prediction, transform = self.test_full_image(dataset_entry)

            utils.io.image.write_np(
                (prediction * 128).astype(np.int8),
                self.output_file_for_current_iteration(current_id +
                                                       '_heatmap.mha'))
            predicted_landmarks = heatmap_test.get_landmarks(
                prediction, reference_image, transformation=transform)
            tensorflow_train.utils.tensorflow_util.print_progress_bar(
                i, self.dataset_val.num_entries())
            landmarks[current_id] = predicted_landmarks
            landmark_statistics.add_landmarks(current_id,
                                              predicted_landmarks,
                                              groundtruth_landmarks,
                                              normalization_factor=50,
                                              normalization_indizes=[1, 5])

        tensorflow_train.utils.tensorflow_util.print_progress_bar(
            self.dataset_val.num_entries(), self.dataset_val.num_entries())
        print('ipe', landmark_statistics.get_ipe_statistics())
        print('pe', landmark_statistics.get_pe_statistics())
        print('outliers',
              landmark_statistics.get_num_outliers([2.0, 4.0, 10.0]))

        # finalize loss values
        self.val_loss_aggregator.finalize(self.current_iter)
        utils.io.landmark.save_points_csv(
            landmarks,
            self.output_file_for_current_iteration('prediction.csv'))
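
The heatmap regression loss used in this last example is tf.nn.l2_loss applied to the difference between target and predicted heatmaps, divided by the batch size returned by get_batch_channel_image_size. A minimal NumPy sketch of the same quantity (illustrative only, assuming channels_first arrays of shape (N, C, H, W)):

import numpy as np

def heatmap_l2_loss(target, prediction):
    """0.5 * sum of squared differences, averaged over the batch dimension.

    Mirrors tf.nn.l2_loss(target - prediction) / batch_size.
    """
    batch_size = target.shape[0]
    return 0.5 * np.sum((target - prediction) ** 2) / batch_size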