class MainLoop(MainLoopBase):
    """
    Inference loop for single-vertebra segmentation.

    For each image and each of its valid landmarks the network is evaluated once; the
    per-landmark predictions are resampled to the input image grid and fused into one
    label volume that is written to the output folder.
    """
    def __init__(self, config):
        """
        Initializer.
        :param config: config object (accessed via attributes). The attributes cv, num_filters_base,
                       activation, num_levels, image_folder, setup_folder, output_folder,
                       load_model_filenames and spacing are read.
        """
        super().__init__()
        # mixed precision is only enabled when TensorFlow reports a GPU device
        self.use_mixed_precision = tf.test.gpu_device_name() != ''
        if self.use_mixed_precision:
            policy = mixed_precision.Policy('mixed_float16')
            mixed_precision.set_policy(policy)
        self.cv = config.cv
        self.config = config
        self.batch_size = 1
        self.num_labels = 1
        self.num_labels_all = 27
        self.data_format = 'channels_last'
        self.network_parameters = OrderedDict(
            num_filters_base=config.num_filters_base,
            activation=config.activation,
            num_levels=config.num_levels,
            data_format=self.data_format)
        self.network = Unet
        self.save_output_images = False
        self.save_debug_images = False
        self.image_folder = config.image_folder
        self.setup_folder = config.setup_folder
        self.output_folder = config.output_folder
        self.load_model_filenames = config.load_model_filenames
        self.image_size = [128, 128, 96]
        self.image_spacing = [config.spacing] * 3
        self.heatmap_size = self.image_size

        # image ids are the file names of all '*.nii.gz' files in the image folder without the extension
        images_files = sorted(glob(os.path.join(self.image_folder, '*.nii.gz')))
        self.image_id_list = [os.path.basename(filename)[:-len('.nii.gz')]
                              for filename in images_files]
        self.valid_landmarks_file = os.path.join(
            self.setup_folder, 'vertebrae_localization/valid_landmarks.csv')
        self.valid_landmarks = utils.io.text.load_dict_csv(
            self.valid_landmarks_file)

        # landmark index (0..25) <-> label value written to the segmentation (1..25 and 28)
        self.landmark_labels = list(range(1, 26)) + [28]
        self.landmark_mapping = dict(enumerate(self.landmark_labels))
        self.landmark_mapping_inverse = {
            label: index for index, label in enumerate(self.landmark_labels)}

    def init_model(self):
        """
        Init self.model.
        """
        self.model = self.network(num_labels=self.num_labels,
                                  **self.network_parameters)

    def init_checkpoint(self):
        """
        Init self.checkpoint.
        """
        self.checkpoint = tf.train.Checkpoint(model=self.model)

    def init_output_folder_handler(self):
        """
        Init self.output_folder_handler.
        """
        self.output_folder_handler = OutputFolderHandler(self.output_folder,
                                                         use_timestamp=False,
                                                         files_to_copy=[])

    def init_datasets(self):
        """
        Init self.dataset_val and self.network_image_size.
        """
        dataset_parameters = dict(
            image_base_folder=self.image_folder,
            setup_base_folder=self.setup_folder,
            image_size=self.image_size,
            image_spacing=self.image_spacing,
            normalize_zero_mean_unit_variance=False,
            cv=self.cv,
            input_gaussian_sigma=0.75,
            heatmap_sigma=3.0,
            generate_single_vertebrae_heatmap=True,
            output_image_type=np.float16 if self.use_mixed_precision else np.float32,
            data_format=self.data_format,
            save_debug_images=self.save_debug_images)

        dataset = Dataset(**dataset_parameters)
        self.dataset_val = dataset.dataset_val()
        self.network_image_size = list(reversed(self.image_size))

    def call_model(self, image):
        """
        Call model.
        :param image: The image to call the model with.
        :return: prediction
        """
        return self.model(image, training=False)

    def test_full_image(self, dataset_entry):
        """
        Perform inference on a dataset_entry with the validation network.
        The image and the single vertebra heatmap are concatenated along the channel axis,
        the sigmoid of the network output is taken and averaged over all models in
        self.load_model_filenames.
        :param dataset_entry: A dataset entry from the dataset.
        :return: input image (np.array), network prediction (np.array), transformation (sitk.Transform)
        """
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        image = np.expand_dims(generators['image'], axis=0)
        single_heatmap = np.expand_dims(generators['single_heatmap'], axis=0)
        concat_axis = 1 if self.data_format == 'channels_first' else -1
        image_heatmap_concat = tf.concat([image, single_heatmap], axis=concat_axis)
        predictions = []
        for load_model_filename in self.load_model_filenames:
            # a single model is loaded once in test(); only reload when several models are averaged
            if len(self.load_model_filenames) > 1:
                self.load_model(load_model_filename)
            prediction = tf.sigmoid(self.call_model(image_heatmap_concat))
            predictions.append(prediction.numpy())
        prediction = np.mean(predictions, axis=0)
        prediction = np.squeeze(prediction, axis=0)
        transformation = transformations['image']
        image = generators['image']

        return image, prediction, transformation

    def test(self):
        """
        The test function. Performs inference for every valid landmark of every image and writes
        the fused label image as '<image_id>_seg.nii.gz' via the output folder handler.
        Images that have no valid landmarks are reported and skipped.
        """
        print('Testing...')

        if len(self.load_model_filenames) == 1:
            self.load_model(self.load_model_filenames[0])

        channel_axis = 3 if self.data_format == 'channels_last' else 0

        filter_largest_cc = True

        # iterate over all images
        for image_id in tqdm(self.image_id_list, desc='Testing'):
            landmark_ids = self.valid_landmarks[image_id] if image_id in self.valid_landmarks else []
            if len(landmark_ids) == 0:
                # without a landmark no dataset entry (and thus no input image) can be loaded
                print('No valid landmarks for', image_id, '- skipping')
                continue

            prediction_labels_np = None
            prediction_max_value_np = None
            input_image = None
            # iterate over all valid landmarks
            for landmark_id in landmark_ids:
                dataset_entry = self.dataset_val.get({
                    'image_id': image_id,
                    'landmark_id': landmark_id
                })
                if input_image is None:
                    # first landmark of this image: allocate the fused outputs in input image resolution
                    input_image = dataset_entry['datasources']['image']
                    volume_shape = list(reversed(input_image.GetSize()))
                    prediction_labels_np = np.zeros(volume_shape, dtype=np.uint8)
                    # initial value 0.5: a voxel is only labeled if some prediction exceeds 0.5
                    prediction_max_value_np = np.full(volume_shape, 0.5, dtype=np.float32)

                image, prediction, transformation = self.test_full_image(
                    dataset_entry)
                del dataset_entry

                # transformed positions of the first corner and the far corner of the network volume
                origin = transformation.TransformPoint(np.zeros(3, np.float64))
                max_index = transformation.TransformPoint(
                    np.array(self.image_size, np.float64) *
                    np.array(self.image_spacing, np.float64))

                if self.save_output_images:
                    utils.io.image.write_multichannel_np(
                        image,
                        self.output_folder_handler.path(
                            'output',
                            image_id + '_' + landmark_id + '_input.mha'),
                        output_normalization_mode='min_max',
                        sitk_image_output_mode='vector',
                        data_format=self.data_format,
                        image_type=np.uint8,
                        spacing=self.image_spacing,
                        origin=origin)
                    utils.io.image.write_multichannel_np(
                        prediction,
                        self.output_folder_handler.path(
                            'output',
                            image_id + '_' + landmark_id + '_prediction.mha'),
                        output_normalization_mode=(0, 1),
                        sitk_image_output_mode='vector',
                        data_format=self.data_format,
                        image_type=np.uint8,
                        spacing=self.image_spacing,
                        origin=origin)
                del image
                prediction = prediction.astype(np.float32)
                prediction_resampled_sitk = utils.sitk_image.transform_np_output_to_sitk_input(
                    output_image=prediction,
                    output_spacing=self.image_spacing,
                    channel_axis=channel_axis,
                    input_image_sitk=input_image,
                    transform=transformation,
                    interpolator='cubic',
                    output_pixel_type=sitk.sitkFloat32)
                del prediction
                prediction_resampled_np = utils.sitk_np.sitk_to_np(
                    prediction_resampled_sitk[0])
                if self.save_output_images:
                    utils.io.image.write_multichannel_np(
                        prediction_resampled_np,
                        self.output_folder_handler.path(
                            'output', image_id + '_' + landmark_id +
                            '_prediction_resampled.mha'),
                        output_normalization_mode=(0, 1),
                        is_single_channel=True,
                        sitk_image_output_mode='vector',
                        data_format=self.data_format,
                        image_type=np.uint8,
                        spacing=prediction_resampled_sitk[0].GetSpacing(),
                        origin=prediction_resampled_sitk[0].GetOrigin())

                # bounding box of the network volume in input image indices; np.flip converts
                # from sitk (x, y, z) to numpy (z, y, x) axis order.
                # NOTE(review): only the spacing is divided out here, which presumably assumes
                # a zero origin and identity direction of input_image - TODO confirm.
                input_spacing = np.array(input_image.GetSpacing())
                bb_start = np.floor(np.flip(origin / input_spacing))
                bb_start = np.maximum(bb_start, [0, 0, 0])
                bb_end = np.ceil(np.flip(max_index / input_spacing))
                bb_end = np.minimum(
                    bb_end, prediction_resampled_np.shape - np.ones(3)
                )  # bb is inclusive -> subtract [1, 1, 1] from max size
                slices = tuple([
                    slice(int(bb_start[i]), int(bb_end[i] + 1))
                    for i in range(3)
                ])
                prediction_resampled_cropped_np = prediction_resampled_np[slices]
                if filter_largest_cc:
                    # set every voxel above the threshold that is not part of the largest
                    # connected component to zero
                    prediction_thresh_cropped_np = (
                        prediction_resampled_cropped_np > 0.5).astype(np.uint8)
                    largest_connected_component = utils.np_image.largest_connected_component(
                        prediction_thresh_cropped_np)
                    prediction_thresh_cropped_np[largest_connected_component == 1] = 0
                    prediction_resampled_cropped_np[prediction_thresh_cropped_np == 1] = 0
                prediction_max_value_cropped_np = prediction_max_value_np[slices]
                prediction_labels_cropped_np = prediction_labels_np[slices]
                # index 1 of the stacked pair marks voxels where the current prediction is selected
                # over the maximum stored so far
                prediction_max_index_np = utils.np_image.argmax(
                    np.stack([prediction_max_value_cropped_np,
                              prediction_resampled_cropped_np],
                             axis=-1),
                    axis=-1)
                prediction_max_index_new_np = prediction_max_index_np == 1
                prediction_max_value_cropped_np[prediction_max_index_new_np] = \
                    prediction_resampled_cropped_np[prediction_max_index_new_np]
                prediction_labels_cropped_np[prediction_max_index_new_np] = \
                    self.landmark_mapping[int(landmark_id)]
                prediction_max_value_np[slices] = prediction_max_value_cropped_np
                prediction_labels_np[slices] = prediction_labels_cropped_np
                del prediction_resampled_sitk

            # delete to save memory
            del prediction_max_value_np
            prediction_labels = utils.sitk_np.np_to_sitk(prediction_labels_np)
            prediction_labels.CopyInformation(input_image)
            del prediction_labels_np
            utils.io.image.write(
                prediction_labels,
                self.output_folder_handler.path(image_id + '_seg.nii.gz'))
            if self.save_output_images:
                prediction_labels_resampled = utils.sitk_np.sitk_to_np(
                    utils.sitk_image.resample_to_spacing(
                        prediction_labels, [1.0, 1.0, 1.0], 'nearest'))
                prediction_labels_resampled = np.flip(
                    prediction_labels_resampled, axis=0)
                utils.io.image.write_multichannel_np(
                    prediction_labels_resampled,
                    self.output_folder_handler.path('output',
                                                    image_id + '_seg.png'),
                    channel_layout_mode='label_rgb',
                    output_normalization_mode=(0, 1),
                    image_layout_mode='max_projection',
                    is_single_channel=True,
                    sitk_image_output_mode='vector',
                    data_format=self.data_format,
                    image_type=np.uint8)
                utils.io.image.write_multichannel_np(
                    prediction_labels_resampled,
                    self.output_folder_handler.path('output',
                                                    image_id + '_seg_rgb.mha'),
                    channel_layout_mode='label_rgb',
                    output_normalization_mode=(0, 1),
                    is_single_channel=True,
                    sitk_image_output_mode='vector',
                    data_format=self.data_format,
                    image_type=np.uint8)
                input_resampled = utils.sitk_np.sitk_to_np(
                    utils.sitk_image.resample_to_spacing(
                        input_image, [1.0, 1.0, 1.0], 'linear'))
                input_resampled = np.flip(input_resampled, axis=0)
                utils.io.image.write_multichannel_np(
                    input_resampled,
                    self.output_folder_handler.path('output',
                                                    image_id + '_input.png'),
                    output_normalization_mode='min_max',
                    image_layout_mode='max_projection',
                    is_single_channel=True,
                    sitk_image_output_mode='vector',
                    data_format=self.data_format,
                    image_type=np.uint8)

            del prediction_labels
# Example #2
# 0
class MainLoop(MainLoopBase):
    """
    Inference loop that predicts one output volume per image, derives a bounding box from it
    and stores all bounding boxes in 'bbs.csv'.
    """
    def __init__(self, config):
        """
        Initializer.
        :param config: config object (accessed via attributes). The attributes cv, num_filters_base,
                       activation, num_levels, model, image_folder, setup_folder, output_folder,
                       load_model_filenames and spacing are read.
        """
        super().__init__()
        # mixed precision is only enabled when TensorFlow reports a GPU device
        self.use_mixed_precision = tf.test.gpu_device_name() != ''
        if self.use_mixed_precision:
            policy = mixed_precision.Policy('mixed_float16')
            mixed_precision.set_policy(policy)
        self.cv = config.cv
        self.config = config
        self.data_format = 'channels_last'
        self.network_parameters = OrderedDict(num_filters_base=config.num_filters_base,
                                              activation=config.activation,
                                              num_levels=config.num_levels,
                                              data_format=self.data_format)
        # NOTE: for any other value of config.model, self.network is not assigned here
        if config.model == 'unet':
            self.network = Unet

        self.save_output_images = True
        self.save_debug_images = False
        self.image_folder = config.image_folder
        self.setup_folder = config.setup_folder
        self.output_folder = config.output_folder
        self.load_model_filenames = config.load_model_filenames
        self.image_size = [None, None, None]
        self.image_spacing = [config.spacing] * 3
        # image ids are the file names of all '*.nii.gz' files in the image folder without the extension
        images_files = sorted(glob(os.path.join(self.image_folder, '*.nii.gz')))
        self.image_id_list = [os.path.basename(filename)[:-len('.nii.gz')] for filename in images_files]

    def init_model(self):
        """
        Init self.model.
        """
        self.model = self.network(num_labels=1, **self.network_parameters)

    def init_checkpoint(self):
        """
        Init self.checkpoint.
        """
        self.checkpoint = tf.train.Checkpoint(model=self.model)

    def init_output_folder_handler(self):
        """
        Init self.output_folder_handler.
        """
        self.output_folder_handler = OutputFolderHandler(self.output_folder, use_timestamp=False, files_to_copy=[])

    def init_datasets(self):
        """
        Init self.dataset_val and self.network_image_size.
        """
        dataset_parameters = dict(image_base_folder=self.image_folder,
                                  setup_base_folder=self.setup_folder,
                                  image_size=self.image_size,
                                  image_spacing=self.image_spacing,
                                  normalize_zero_mean_unit_variance=False,
                                  valid_output_sizes_x=[32, 64, 96, 128],
                                  valid_output_sizes_y=[32, 64, 96, 128],
                                  valid_output_sizes_z=[32, 64, 96, 128],
                                  use_variable_image_size=True,
                                  cv=self.cv,
                                  input_gaussian_sigma=0.75,
                                  output_image_type=np.float16 if self.use_mixed_precision else np.float32,
                                  data_format=self.data_format,
                                  save_debug_images=self.save_debug_images)

        dataset = Dataset(**dataset_parameters)
        self.dataset_val = dataset.dataset_val()
        self.network_image_size = list(reversed(self.image_size))

    def call_model(self, image):
        """
        Call model.
        :param image: The image to call the model with.
        :return: prediction
        """
        return self.model(image, training=False)

    def test_full_image(self, dataset_entry):
        """
        Perform inference on a dataset_entry with the validation network.
        The prediction is averaged over all models in self.load_model_filenames.
        :param dataset_entry: A dataset entry from the dataset.
        :return: input image (np.array), network prediction (np.array), transformation (sitk.Transform)
        """
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        image = np.expand_dims(generators['image'], axis=0)
        predictions = []
        for load_model_filename in self.load_model_filenames:
            # a single model is loaded once in test(); only reload when several models are averaged
            if len(self.load_model_filenames) > 1:
                self.load_model(load_model_filename)
            prediction = self.call_model(image)
            predictions.append(prediction.numpy())
        prediction = np.mean(predictions, axis=0)
        prediction = np.squeeze(prediction, axis=0)
        transformation = transformations['image']
        image = generators['image']

        return image, prediction, transformation

    def test(self):
        """
        The test function. Performs inference on all images, calculates a bounding box from each
        prediction and saves the bounding boxes of all images as 'bbs.csv'.
        An image that raises an exception is reported on stdout and left out of the result.
        """
        print('Testing...')

        if len(self.load_model_filenames) == 1:
            self.load_model(self.load_model_filenames[0])

        bbs = {}
        for current_id in tqdm(self.image_id_list, desc='Testing'):
            try:
                dataset_entry = self.dataset_val.get({'image_id': current_id})
                image, prediction, transformation = self.test_full_image(dataset_entry)
                start, end = bb(prediction, transformation, self.image_spacing)
                # one entry per image: start followed by end
                bbs[current_id] = start + end

                if self.save_output_images:
                    origin = np.array(transformation.TransformPoint(np.zeros(3, np.float64)))
                    utils.io.image.write_multichannel_np(image,
                                                         self.output_folder_handler.path('output', current_id + '_input.mha'),
                                                         output_normalization_mode='min_max',
                                                         data_format=self.data_format,
                                                         image_type=np.uint8,
                                                         spacing=self.image_spacing,
                                                         origin=origin)
                    utils.io.image.write_multichannel_np(prediction,
                                                         self.output_folder_handler.path('output', current_id + '_prediction.mha'),
                                                         output_normalization_mode=(0, 1),
                                                         data_format=self.data_format,
                                                         image_type=np.uint8,
                                                         spacing=self.image_spacing,
                                                         origin=origin)
            except Exception:
                # best effort: report the failure and continue with the next image
                print('ERROR predicting', current_id)
                traceback.print_exc(file=sys.stdout)

        utils.io.text.save_dict_csv(bbs, self.output_folder_handler.path('bbs.csv'))
# Example #3
# 0
class MainLoop(MainLoopBase):
    def __init__(self, cv, config):
        """
        Set up hyperparameters, paths and the traced model/loss function.
        :param cv: The cv fold. 0, 1, 2 for CV; 'train_all' for training on whole dataset.
        :param config: config object (accessed via attributes). The attributes learning_rate,
                       num_filters_base, activation, dropout_ratio, num_levels, model, spacing
                       and info are read.
        """
        super().__init__()
        self.use_mixed_precision = True
        if self.use_mixed_precision:
            mixed_precision.set_policy(mixed_precision.Policy('mixed_float16'))
        self.cv = cv
        self.config = config

        # optimization schedule
        self.batch_size = 1
        base_learning_rate = config.learning_rate
        self.learning_rate = base_learning_rate
        self.learning_rates = [base_learning_rate, base_learning_rate * 0.5, base_learning_rate * 0.1]
        self.learning_rate_boundaries = [50000, 75000]
        self.max_iter = 10000
        self.test_iter = 5000
        self.disp_iter = 100
        self.snapshot_iter = 5000
        self.test_initialization = False
        self.reg_constant = 0.0

        # network definition
        self.data_format = 'channels_first'
        self.network_parameters = OrderedDict([('num_filters_base', config.num_filters_base),
                                               ('activation', config.activation),
                                               ('dropout_ratio', config.dropout_ratio),
                                               ('num_levels', config.num_levels),
                                               ('heatmap_initialization', True),
                                               ('data_format', self.data_format)])
        if config.model == 'unet':
            self.network = Unet
        self.clip_gradient_global_norm = 100000.0

        # dataset and output settings
        self.use_pyro_dataset = True
        self.save_output_images = True
        self.save_debug_images = False
        self.has_validation_groundtruth = cv in [0, 1, 2]
        self.local_base_folder = '../verse2020_dataset'
        self.image_size = [None, None, None]
        self.image_spacing = [config.spacing] * 3
        self.sigma_regularization = 100.0
        self.sigma_scale = 1000.0
        self.cropped_training = True
        self.base_output_folder = './output/spine_localization/'
        self.additional_output_folder_info = config.info

        # wrap call_model_and_loss in a tf.function with a fixed signature:
        # (image, target_heatmap, training)
        image_dtype = tf.float16 if self.use_mixed_precision else tf.float32
        image_spec = tf.TensorSpec(tf.TensorShape([1, 1] + list(reversed(self.image_size))), image_dtype)
        heatmap_spec = tf.TensorSpec(tf.TensorShape([1, 1] + list(reversed(self.image_size))), tf.float32)
        training_spec = tf.TensorSpec(tf.TensorShape(None), tf.bool)
        self.call_model_and_loss = tf.function(self.call_model_and_loss,
                                               input_signature=[image_spec, heatmap_spec, training_spec])

    def init_model(self):
        """
        Create self.norm_moving_average (a tf.Variable starting at 10.0) and self.model.
        """
        self.norm_moving_average = tf.Variable(10.0)
        network_class = self.network
        self.model = network_class(num_labels=1, **self.network_parameters)

    def init_optimizer(self):
        """
        Create self.optimizer: Adam with an exponentially decaying learning rate, wrapped in a
        LossScaleOptimizer with dynamic loss scaling when mixed precision is used.
        Note that self.learning_rate is replaced by the learning rate schedule.
        """
        schedule = tf.keras.optimizers.schedules.ExponentialDecay(self.learning_rate, self.max_iter, 0.1)
        self.learning_rate = schedule
        optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)
        if self.use_mixed_precision:
            dynamic_loss_scale = tf.mixed_precision.experimental.DynamicLossScale(initial_loss_scale=2 ** 15,
                                                                                  increment_period=1000)
            optimizer = mixed_precision.LossScaleOptimizer(optimizer, loss_scale=dynamic_loss_scale)
        self.optimizer = optimizer

    def init_checkpoint(self):
        """
        Create self.checkpoint, which tracks both the model and the optimizer.
        """
        tracked_objects = dict(model=self.model, optimizer=self.optimizer)
        self.checkpoint = tf.train.Checkpoint(**tracked_objects)

    def init_output_folder_handler(self):
        """
        Create self.output_folder_handler from the base output folder, the model name,
        the cv fold and the additional info string.
        """
        handler_arguments = dict(model_name=self.model.name,
                                 cv=str(self.cv),
                                 additional_info=self.additional_output_folder_info)
        self.output_folder_handler = OutputFolderHandler(self.base_output_folder, **handler_arguments)

    def init_datasets(self):
        """
        Create self.dataset_train, self.dataset_train_iter, self.dataset_val and
        self.network_image_size.
        """
        valid_output_sizes = [32, 64, 96, 128]
        image_type = np.float16 if self.use_mixed_precision else np.float32
        dataset_parameters = {'base_folder': self.local_base_folder,
                              'image_size': self.image_size,
                              'image_spacing': self.image_spacing,
                              'normalize_zero_mean_unit_variance': False,
                              'cv': self.cv,
                              'heatmap_sigma': 3.0,
                              'generate_spine_heatmap': True,
                              'use_variable_image_size': True,
                              'valid_output_sizes_x': list(valid_output_sizes),
                              'valid_output_sizes_y': list(valid_output_sizes),
                              'valid_output_sizes_z': list(valid_output_sizes),
                              'output_image_type': image_type,
                              'data_format': self.data_format,
                              'save_debug_images': self.save_debug_images}

        dataset = Dataset(**dataset_parameters)
        if not self.use_pyro_dataset:
            self.dataset_train = dataset.dataset_train()
        else:
            # TODO: adapt hostname, in case this script runs on a remote server
            uri = 'PYRO:verse2020_dataset' + '@' + socket.gethostname() + ':52132'
            print('using pyro uri', uri)
            try:
                self.dataset_train = PyroClientDataset(uri, **dataset_parameters)
            except Exception as e:
                print('Error connecting to server dataset. Start server_dataset_loop.py and set correct hostname, or set self.use_pyro_dataset = False.')
                raise e

        self.dataset_val = dataset.dataset_val()
        self.network_image_size = list(reversed(self.image_size))

        # shape of a single (unbatched) image or heatmap entry, with one channel
        if self.data_format == 'channels_first':
            entry_shape = [1] + self.network_image_size
        else:
            entry_shape = self.network_image_size + [1]
        data_generator_entries = OrderedDict([('image', list(entry_shape)),
                                              ('spine_heatmap', list(entry_shape)),
                                              ('image_id', tuple())])

        data_generator_types = {'image': tf.float16 if self.use_mixed_precision else tf.float32,
                                'spine_heatmap': tf.float32,
                                'image_id': tf.string}
        self.dataset_train_iter = DatasetIterator(dataset=self.dataset_train,
                                                  data_names_and_shapes=data_generator_entries,
                                                  data_types=data_generator_types,
                                                  batch_size=self.batch_size)

    def init_loggers(self):
        """
        Create self.loss_metric_logger_train and self.loss_metric_logger_val. Each logger gets
        an output folder and a csv file inside the output folder handler's directory.
        """
        path = self.output_folder_handler.path
        self.loss_metric_logger_train = LossMetricLogger('train', path('train'), path('train.csv'))
        self.loss_metric_logger_val = LossMetricLogger('test', path('test'), path('test.csv'))

    @tf.function
    def loss_function(self, pred, target, mask=None):
        """
        L2 loss function calculated with prediction and target.
        :param pred: The predicted image.
        :param target: The target image.
        :param mask: If not None, the difference (pred - target) is multiplied with mask
                     before the loss is calculated.
        :return: tf.nn.l2_loss(diff) / (batch_size * 1024)
        """
        # only the batch size is needed; channel and image size are not part of the normalization
        batch_size, _, _ = get_batch_channel_image_size(pred, self.data_format, as_tensor=True)
        diff = pred - target
        if mask is not None:
            diff = diff * mask
        # the normalization uses the fixed factor 1024 and not the number of channels or voxels
        return tf.nn.l2_loss(diff) / tf.cast(batch_size * 1024, tf.float32)

    def call_model_and_loss(self, image, target_heatmap, training):
        """
        Run the model on image and compute the loss against target_heatmap.
        :param image: The image to call the model with.
        :param target_heatmap: The target heatmap used for loss calculation.
        :param training: training parameter used for calling the model.
        :return: (prediction, losses) tuple, where losses is a dict with the single key 'loss_net'
        """
        prediction = self.model(image, training=training)
        loss_net = self.loss_function(target=target_heatmap, pred=prediction)
        return prediction, {'loss_net': loss_net}

    @tf.function
    def train_step(self):
        """
        Perform a training step.
        Fetches the next entry from self.dataset_train_iter, calculates the losses, clips the
        gradients by their global norm (threshold: 5 times the moving average of the norm),
        applies them with self.optimizer and passes the losses and gradient statistics to
        self.loss_metric_logger_train.
        """
        image, target_landmarks, image_id = self.dataset_train_iter.get_next()
        with tf.GradientTape() as tape:
            _, losses = self.call_model_and_loss(image, target_landmarks, training=True)
            # the regularization term is only added for a positive reg_constant
            if self.reg_constant > 0:
                losses['loss_reg'] = self.reg_constant * tf.reduce_sum(self.model.losses)
            loss = tf.reduce_sum(list(losses.values()))
            if self.use_mixed_precision:
                scaled_loss = self.optimizer.get_scaled_loss(loss)
        variables = self.model.trainable_weights
        # alias, not a copy: the update() calls below also add their entries to losses
        metric_dict = losses
        clip_norm = self.norm_moving_average * 5
        if self.use_mixed_precision:
            # gradients are taken from the scaled loss and unscaled again before clipping
            scaled_grads = tape.gradient(scaled_loss, variables)
            grads = self.optimizer.get_unscaled_gradients(scaled_grads)
            grads, norm = tf.clip_by_global_norm(grads, clip_norm)
            loss_scale = self.optimizer.loss_scale()
            metric_dict.update({'loss_scale': loss_scale})
        else:
            grads = tape.gradient(loss, variables)
            grads, norm = tf.clip_by_global_norm(grads, clip_norm)
        # the moving average is only updated for a finite norm; the contribution of a single
        # step is limited to clip_norm
        if tf.math.is_finite(norm):
            alpha = 0.01
            self.norm_moving_average.assign(alpha * tf.minimum(norm, clip_norm) + (1 - alpha) * self.norm_moving_average)
        metric_dict.update({'norm': norm, 'norm_average': self.norm_moving_average})
        self.optimizer.apply_gradients(zip(grads, variables))

        self.loss_metric_logger_train.update_metrics(metric_dict)

    def test_full_image(self, dataset_entry):
        """
        Run the validation network on the full image of a dataset_entry.
        If validation groundtruth is available, the losses are calculated as well and passed to the validation logger.
        :param dataset_entry: A dataset entry from the dataset.
        :return: input image (np.array), network prediction (np.array), transformation (sitk.Transform)
        """
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        image = generators['image']
        # the network expects a leading batch dimension
        batched_image = np.expand_dims(image, axis=0)
        if not self.has_validation_groundtruth:
            prediction = self.model(batched_image, training=False)
        else:
            batched_heatmap = np.expand_dims(generators['spine_heatmap'], axis=0)
            prediction, losses = self.call_model_and_loss(batched_image, batched_heatmap, training=False)
            self.loss_metric_logger_val.update_metrics(losses)

        # remove the batch dimension again before returning
        return image, np.squeeze(prediction, axis=0), transformations['image']

    def test(self):
        """
        The test function. Runs inference on every validation image, optionally writes input and prediction
        images, and logs the mean bounding box IoU if validation groundtruth is available.
        """
        print('Testing...')

        num_entries = self.dataset_val.num_entries()
        ious = {}
        for _ in tqdm(range(num_entries), desc='Testing'):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            print(current_id)
            image, prediction, transformation = self.test_full_image(dataset_entry)
            pred_start, pred_end = bb(prediction, transformation, self.image_spacing)
            if self.has_validation_groundtruth:
                groundtruth = dataset_entry['generators']['spine_heatmap']
                gt_start, gt_end = bb(groundtruth, transformation, self.image_spacing)
                ious[current_id] = bb_iou((pred_start, pred_end), (gt_start, gt_end))

            if not self.save_output_images:
                continue

            origin = transformation.TransformPoint(np.zeros(3, np.float64))
            # settings shared by both written images
            shared_kwargs = dict(sitk_image_output_mode='vector',
                                 data_format=self.data_format,
                                 image_type=np.uint8,
                                 spacing=self.image_spacing,
                                 origin=origin)
            outputs = ((image, '_input.mha', 'min_max'),
                       (prediction, '_prediction.mha', (0, 1)))
            for data, suffix, normalization_mode in outputs:
                utils.io.image.write_multichannel_np(data,
                                                     self.output_folder_handler.path_for_iteration(self.current_iter, current_id + suffix),
                                                     output_normalization_mode=normalization_mode,
                                                     **shared_kwargs)

        # finalize loss values
        if self.has_validation_groundtruth:
            self.loss_metric_logger_val.update_metrics({'mean_iou': np.mean(list(ious.values()))})

        self.loss_metric_logger_val.finalize(self.current_iter)
class MainLoop(MainLoopBase):
    def __init__(self, cv, config):
        """
        Initializer. Sets all hyperparameters, folders and the network configuration, and wraps
        call_model_and_loss into a tf.function with a fixed input signature.
        :param cv: The cv fold. 0, 1, 2 for CV; 'train_all' for training on whole dataset.
        :param config: config object; its attributes (learning_rate, heatmap_sigma, model, spacing, ...) are read here.
        """
        super().__init__()

        # mixed precision is always switched on for this loop
        self.use_mixed_precision = True
        if self.use_mixed_precision:
            mixed_precision.set_policy(mixed_precision.Policy('mixed_float16'))

        # general setup
        self.cv = cv
        self.config = config
        self.batch_size = 1

        # learning rate and iteration schedule
        base_learning_rate = config.learning_rate
        self.learning_rate = base_learning_rate
        self.learning_rates = [base_learning_rate, base_learning_rate * 0.5, base_learning_rate * 0.1]
        self.learning_rate_boundaries = [50000, 75000]
        self.has_validation_groundtruth = cv in [0, 1, 2]
        self.max_iter = 50000
        if self.has_validation_groundtruth:
            self.test_iter = 5000
        else:
            # without groundtruth, testing is only done at the very end
            self.test_iter = self.max_iter
        self.disp_iter = 100
        self.snapshot_iter = 5000
        self.test_initialization = False
        self.reg_constant = 0.0
        self.use_background = True

        # landmarks and heatmaps
        self.num_landmarks = 26
        self.heatmap_sigma = config.heatmap_sigma
        self.learnable_sigma = config.learnable_sigma

        # network
        self.data_format = 'channels_first'
        self.network_parameters = OrderedDict([('num_filters_base', config.num_filters_base),
                                               ('activation', config.activation),
                                               ('spatial_downsample', config.spatial_downsample),
                                               ('dropout_ratio', config.dropout_ratio),
                                               ('local_activation', config.local_activation),
                                               ('spatial_activation', config.spatial_activation),
                                               ('num_levels', config.num_levels),
                                               ('data_format', self.data_format)])
        if config.model == 'scn':
            self.network = SpatialConfigurationNet
        elif config.model == 'unet':
            self.network = Unet
        self.clip_gradient_global_norm = 10000.0

        # evaluation, dataset and output
        self.evaluate_landmarks_postprocessing = True
        self.use_pyro_dataset = True
        self.save_output_images = True
        self.save_debug_images = False
        self.local_base_folder = '../verse2020_dataset'
        self.image_size = [None] * 3
        self.image_spacing = [config.spacing] * 3
        self.max_image_size_for_cropped_test = [128, 128, 448]
        self.cropped_inc = [0, 128, 0, 0]
        self.heatmap_size = self.image_size
        self.sigma_regularization = 100.0
        self.sigma_scale = 1000.0
        self.cropped_training = True
        self.base_output_folder = './output/vertebrae_localization/'
        self.additional_output_folder_info = config.info

        # fixed input signature for call_model_and_loss: (image, landmarks, training flag)
        image_dtype = tf.float16 if self.use_mixed_precision else tf.float32
        spatial_dims = list(reversed(self.image_size))
        if self.data_format == 'channels_first':
            image_shape = tf.TensorShape([1, 1] + spatial_dims)
        else:
            image_shape = tf.TensorShape([1] + spatial_dims) + [1]
        input_signature = [tf.TensorSpec(image_shape, image_dtype),
                           tf.TensorSpec(tf.TensorShape([1, self.num_landmarks, 4]), tf.float32),
                           tf.TensorSpec(tf.TensorShape(None), tf.bool)]
        self.call_model_and_loss = tf.function(self.call_model_and_loss, input_signature=input_signature)

    def init_model(self):
        """
        Init self.model, the sigma variables and the moving average of the gradient norm.
        """
        self.norm_moving_average = tf.Variable(10.0)
        # one sigma per landmark, all starting at the configured heatmap sigma
        initial_sigmas = [self.heatmap_sigma] * self.num_landmarks
        self.sigmas_variables = tf.Variable(initial_sigmas, name='sigmas', trainable=True)
        if self.learnable_sigma:
            self.sigmas = self.sigmas_variables
        else:
            # sigmas are not learned: block gradients
            self.sigmas = tf.stop_gradient(self.sigmas_variables)
        self.model = self.network(num_labels=self.num_landmarks, **self.network_parameters)

    def init_optimizer(self):
        """
        Init self.optimizer. Also replaces self.learning_rate with an exponential decay schedule
        (factor 0.1 over self.max_iter steps).
        """
        schedule = tf.keras.optimizers.schedules.ExponentialDecay(self.learning_rate, self.max_iter, 0.1)
        self.learning_rate = schedule
        optimizer = tf.keras.optimizers.Adam(learning_rate=schedule, amsgrad=False)
        if self.use_mixed_precision:
            # wrap the optimizer with dynamic loss scaling for mixed precision training
            dynamic_loss_scale = tf.mixed_precision.experimental.DynamicLossScale(initial_loss_scale=8, increment_period=1000)
            optimizer = mixed_precision.LossScaleOptimizer(optimizer, loss_scale=dynamic_loss_scale)
        self.optimizer = optimizer

    def init_checkpoint(self):
        """
        Init self.checkpoint, which tracks the model, the optimizer and the sigma variables.
        """
        tracked_objects = {'model': self.model,
                           'optimizer': self.optimizer,
                           'sigmas': self.sigmas_variables}
        self.checkpoint = tf.train.Checkpoint(**tracked_objects)

    def init_output_folder_handler(self):
        """
        Init self.output_folder_handler from the base output folder, the model name, the cv fold and the additional info.
        """
        handler_options = dict(model_name=self.model.name,
                               cv=str(self.cv),
                               additional_info=self.additional_output_folder_info)
        self.output_folder_handler = OutputFolderHandler(self.base_output_folder, **handler_options)

    def init_datasets(self):
        """
        Init self.dataset_train, self.dataset_train_iter, self.dataset_val.
        Also sets self.network_image_size (the reversed self.image_size).
        """
        dataset_parameters = {'base_folder': self.local_base_folder,
                              'image_size': self.image_size,
                              'image_spacing': self.image_spacing,
                              'num_landmarks': self.num_landmarks,
                              'normalize_zero_mean_unit_variance': False,
                              'cv': self.cv,
                              'generate_landmarks': True,
                              'generate_landmark_mask': False,
                              'crop_image_top_bottom': True,
                              'crop_randomly_smaller': False,
                              'generate_heatmaps': False,
                              'use_variable_image_size': True,
                              'valid_output_sizes_x': [64, 96],
                              'valid_output_sizes_y': [64, 96],
                              'valid_output_sizes_z': [32, 64, 96, 128, 160, 192, 224, 256],
                              'translate_to_center_landmarks': True,
                              'translate_by_random_factor': True,
                              'data_format': self.data_format,
                              'save_debug_images': self.save_debug_images}

        dataset = Dataset(**dataset_parameters)
        if not self.use_pyro_dataset:
            self.dataset_train = dataset.dataset_train()
        else:
            # TODO: adapt hostname, in case this script runs on a remote server
            uri = 'PYRO:verse2020_dataset' + '@' + socket.gethostname() + ':52132'
            print('using pyro uri', uri)
            try:
                self.dataset_train = PyroClientDataset(uri, **dataset_parameters)
            except Exception as error:
                print('Error connecting to server dataset. Start server_dataset_loop.py and set correct hostname, or set self.use_pyro_dataset = False.')
                raise error

        self.dataset_val = dataset.dataset_val()
        self.network_image_size = list(reversed(self.image_size))

        # the channel dimension is first or last, depending on the data format
        if self.data_format == 'channels_first':
            image_shape = [1] + self.network_image_size
        else:
            image_shape = self.network_image_size + [1]
        data_generator_entries = OrderedDict([('image', image_shape),
                                              ('landmarks', [self.num_landmarks, 4]),
                                              ('image_id', tuple())])

        image_dtype = tf.float16 if self.use_mixed_precision else tf.float32
        data_generator_types = {'image': image_dtype, 'image_id': tf.string}
        self.dataset_train_iter = DatasetIterator(dataset=self.dataset_train,
                                                  data_names_and_shapes=data_generator_entries,
                                                  data_types=data_generator_types,
                                                  batch_size=self.batch_size)

    def init_loggers(self):
        """
        Init self.loss_metric_logger_train, self.loss_metric_logger_val.
        Both loggers get a folder and a csv file inside the output folder.
        """
        output_path = self.output_folder_handler.path
        self.loss_metric_logger_train = LossMetricLogger('train', output_path('train'), output_path('train.csv'))
        self.loss_metric_logger_val = LossMetricLogger('test', output_path('test'), output_path('test.csv'))

    @tf.function
    def loss_function(self, pred, target, mask=None):
        """
        L2 loss between prediction and target, normalized by the batch size.
        :param pred: The predicted image.
        :param target: The target image.
        :param mask: If not None, the difference (pred - target) is multiplied with this mask before the loss is calculated.
        :return: tf.nn.l2_loss of the (optionally masked) difference, divided by batch_size.
        """
        # only the batch size is needed; channel and image size are ignored
        batch_size, _, _ = get_batch_channel_image_size(pred, self.data_format, as_tensor=True)
        residual = pred - target
        if mask is not None:
            residual = residual * mask
        normalizer = tf.cast(batch_size, tf.float32)
        return tf.nn.l2_loss(residual) / normalizer

    @tf.function
    def loss_function_sigmas(self, sigmas, valid_landmarks):
        """
        L2 loss function for sigmas. Sigmas are weighted with valid_landmarks, so entries where valid_landmarks == 0 do not contribute.
        :param sigmas: Sigma variables.
        :param valid_landmarks: Valid landmarks. Needs to have same shape as sigmas.
        :return: L2 loss of sigmas * valid_landmarks.
        """
        weighted_sigmas = sigmas * valid_landmarks
        return tf.nn.l2_loss(weighted_sigmas)

    def call_model_and_loss(self, image, target_landmarks, training):
        """
        Call the model, generate the target heatmaps from the landmarks and calculate all losses.
        :param image: The image to call the model with.
        :param target_landmarks: The target landmarks used for heatmap generation and loss calculation.
        :param training: training parameter used for calling the model.
        :return ((prediction, local_prediction, spatial_prediction), losses) tuple
        """
        prediction, local_prediction, spatial_prediction = self.model(image, training=training)

        # spatial size of the input, without batch and channel dimension
        if self.data_format == 'channels_first':
            heatmap_shape = tf.shape(image)[2:]
        else:
            heatmap_shape = tf.shape(image)[1:-1]
        target_heatmaps = generate_heatmap_target(heatmap_shape, target_landmarks, self.sigmas,
                                                  scale=1.0, normalize=False, data_format=self.data_format)

        losses = {'loss_net': self.loss_function(pred=prediction, target=target_heatmaps)}
        if self.learnable_sigma and self.sigma_regularization > 0:
            # first landmark component of the first batch entry is used as the weight of each sigma
            sigma_weights = target_landmarks[0, :, 0]
            losses['loss_sigmas'] = self.sigma_regularization * self.loss_function_sigmas(self.sigmas, sigma_weights)
        if self.reg_constant > 0:
            losses['loss_reg'] = self.reg_constant * tf.reduce_sum(self.model.losses)

        outputs = (prediction, local_prediction, spatial_prediction)
        return outputs, losses

    @tf.function
    def train_step(self):
        """
        Perform a training step.

        Fetches the next training entry, calculates the losses, clips the gradients at twice the moving
        average of the gradient norm and applies them. If self.learnable_sigma is set, the sigma variables
        are trained as well: they are differentiated together with the model weights and updated with
        their own gradient. An optional self.optimizer_sigma is used for the sigma update if it exists,
        otherwise self.optimizer updates the sigmas together with the model weights.
        """
        image, target_landmarks, _ = self.dataset_train_iter.get_next()
        with tf.GradientTape() as tape:
            _, losses = self.call_model_and_loss(image, target_landmarks, training=True)
            loss = tf.reduce_sum(list(losses.values()))
            if self.use_mixed_precision:
                scaled_loss = self.optimizer.get_scaled_loss(loss)

        # The sigmas are not part of self.model. They have to be appended explicitly, otherwise no
        # gradient is calculated for them and the last model weight's gradient would be used instead.
        variables = list(self.model.trainable_weights)
        if self.learnable_sigma:
            variables.append(self.sigmas_variables)

        metric_dict = losses
        # gradients are clipped at twice the running average of the gradient norm
        clip_norm = self.norm_moving_average * 2
        if self.use_mixed_precision:
            scaled_grads = tape.gradient(scaled_loss, variables)
            grads = self.optimizer.get_unscaled_gradients(scaled_grads)
            grads, norm = tf.clip_by_global_norm(grads, clip_norm)
            loss_scale = self.optimizer.loss_scale()
            metric_dict.update({'loss_scale': loss_scale})
        else:
            grads = tape.gradient(loss, variables)
            grads, norm = tf.clip_by_global_norm(grads, clip_norm)
        # only finite norms update the running average
        if tf.math.is_finite(norm):
            alpha = 0.99
            self.norm_moving_average.assign((1-alpha) * tf.minimum(norm, clip_norm) + alpha * self.norm_moving_average)
        metric_dict.update({'norm': norm, 'norm_average': self.norm_moving_average})

        # optimizer_sigma is not created by init_optimizer; only use it if it was set elsewhere
        sigma_optimizer = getattr(self, 'optimizer_sigma', None)
        if self.learnable_sigma and sigma_optimizer is not None:
            # last entry of grads/variables belongs to the sigmas, everything before to the model
            self.optimizer.apply_gradients(zip(grads[:-1], variables[:-1]))
            sigma_optimizer.apply_gradients(zip(grads[-1:], variables[-1:]))
        else:
            self.optimizer.apply_gradients(zip(grads, variables))

        if self.learnable_sigma:
            metric_dict['mean_sigmas'] = tf.reduce_mean(self.sigmas)

        self.loss_metric_logger_train.update_metrics(metric_dict)

    def test_cropped_image(self, dataset_entry):
        """
        Perform inference on a dataset_entry with the validation network. Performs cropped prediction and merges outputs as maxima.
        If validation groundtruth is available, the losses of every crop are passed to the validation logger.
        :param dataset_entry: A dataset entry from the dataset.
        :return: tuple of (input image, prediction, local prediction, spatial prediction, transformation (sitk.Transform));
                 the first four entries are the output_image of the respective tiler.
        """
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        transformation = transformations['image']

        full_image = generators['image']
        if self.has_validation_groundtruth:
            landmarks = generators['landmarks']

        # crop size: the image size, limited by the (reversed) maximum size for cropped testing
        # NOTE(review): full_image.shape[1:] assumes the channel is the first dimension - confirm for channels_last
        image_size_for_tilers = np.minimum(full_image.shape[1:], list(reversed(self.max_image_size_for_cropped_test))).tolist()

        image_size_np = [1] + image_size_for_tilers
        labels_size_np = [self.num_landmarks] + image_size_for_tilers
        # one tiler per input/output; the prediction tilers use -np.inf as their last argument
        # (presumably the default value of not yet written voxels - verify against ImageTiler)
        image_tiler = ImageTiler(full_image.shape, image_size_np, self.cropped_inc, True, -1)
        landmark_tiler = LandmarkTiler(full_image.shape, image_size_np, self.cropped_inc)
        prediction_tiler = ImageTiler((self.num_landmarks,) + full_image.shape[1:], labels_size_np, self.cropped_inc, True, -np.inf)
        prediction_local_tiler = ImageTiler((self.num_landmarks,) + full_image.shape[1:], labels_size_np, self.cropped_inc, True, -np.inf)
        prediction_spatial_tiler = ImageTiler((self.num_landmarks,) + full_image.shape[1:], labels_size_np, self.cropped_inc, True, -np.inf)
        # All tilers are iterated in lockstep. The loop variables shadow the tiler objects on purpose:
        # the names are read again after the loop (presumably every tiler yields itself - verify against ImageTiler).
        for image_tiler, landmark_tiler, prediction_tiler, prediction_local_tiler, prediction_spatial_tiler in zip(image_tiler, landmark_tiler, prediction_tiler, prediction_local_tiler, prediction_spatial_tiler):
            current_image = image_tiler.get_current_data(full_image)
            if self.has_validation_groundtruth:
                current_landmarks = landmark_tiler.get_current_data(landmarks)
                # add the batch dimension before calling the network
                (prediction, prediction_local, prediction_spatial), losses = self.call_model_and_loss(np.expand_dims(current_image, axis=0),
                                                                                                      np.expand_dims(current_landmarks, axis=0), training=False)
                self.loss_metric_logger_val.update_metrics(losses)
            else:
                prediction, prediction_local, prediction_spatial = self.model(np.expand_dims(current_image, axis=0), training=False)
            # write the current crop back into the tilers (batch dimension removed)
            image_tiler.set_current_data(current_image)
            prediction_tiler.set_current_data(np.squeeze(prediction, axis=0))
            prediction_local_tiler.set_current_data(np.squeeze(prediction_local, axis=0))
            prediction_spatial_tiler.set_current_data(np.squeeze(prediction_spatial, axis=0))

        return image_tiler.output_image, prediction_tiler.output_image, prediction_local_tiler.output_image, prediction_spatial_tiler.output_image, transformation

    def test(self):
        """
        The test function. Performs inference on the validation images, extracts the landmarks from the
        predicted heatmaps (with and without graph-based postprocessing), writes images, visualizations and
        landmark csv files for the current iteration, and logs the landmark statistics if groundtruth is available.
        """
        print('Testing...')
        # visualization helper; the annotations map landmark indices to vertebra names
        vis = LandmarkVisualizationMatplotlib(dim=3,
                                              annotations=dict([(i, f'C{i + 1}') for i in range(7)] +        # 0-6: C1-C7
                                                               [(i, f'T{i - 6}') for i in range(7, 19)] +    # 7-18: T1-12
                                                               [(i, f'L{i - 18}') for i in range(19, 25)] +  # 19-24: L1-6
                                                               [(25, 'T13')]))                               # 25: T13

        # axis of the heatmap channels in the prediction (without batch dimension)
        channel_axis = 0
        if self.data_format == 'channels_last':
            channel_axis = 3
        heatmap_maxima = HeatmapTest(channel_axis,
                                     False,
                                     return_multiple_maxima=True,
                                     min_max_value=0.05,
                                     smoothing_sigma=2.0)

        # NOTE(review): both pickle files are read relative to the current working directory.
        # pickle.load can execute arbitrary code - only use trusted files here.
        with open('possible_successors.pickle', 'rb') as f:
            possible_successors = pickle.load(f)
        with open('units_distances.pickle', 'rb') as f:
            offsets_mean, distances_mean, distances_std = pickle.load(f)
        spine_postprocessing = SpinePostprocessingGraph(num_landmarks=self.num_landmarks,
                                                        possible_successors=possible_successors,
                                                        offsets_mean=offsets_mean,
                                                        distances_mean=distances_mean,
                                                        distances_std=distances_std,
                                                        bias=2.0,
                                                        l=0.2)

        # statistics and landmarks are collected separately for postprocessed and raw results
        landmark_statistics = LandmarkStatistics()
        landmarks = {}
        landmark_statistics_no_postprocessing = LandmarkStatistics()
        landmarks_no_postprocessing = {}
        # NOTE(review): all_local_maxima_landmarks is never used below
        all_local_maxima_landmarks = {}
        num_entries = self.dataset_val.num_entries()
        for _ in tqdm(range(num_entries), desc='Testing'):
            dataset_entry = self.dataset_val.get_next()
            current_id = dataset_entry['id']['image_id']
            datasources = dataset_entry['datasources']
            input_image = datasources['image']
            if self.has_validation_groundtruth:
                target_landmarks = datasources['landmarks']
            else:
                target_landmarks = None

            image, prediction, prediction_local, prediction_spatial, transformation = self.test_cropped_image(dataset_entry)

            # origin of the written images: the transformed zero point
            origin = transformation.TransformPoint(np.zeros(3, np.float64))
            if self.save_output_images:
                heatmap_normalization_mode = (-1, 1)
                image_type = np.uint8
                utils.io.image.write_multichannel_np(image,self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_input.mha'), output_normalization_mode='min_max', sitk_image_output_mode='vector', data_format=self.data_format, image_type=image_type, spacing=self.image_spacing, origin=origin)
                utils.io.image.write_multichannel_np(prediction, self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_prediction.mha'), output_normalization_mode=heatmap_normalization_mode, sitk_image_output_mode='vector', data_format=self.data_format, image_type=image_type, spacing=self.image_spacing, origin=origin)
                utils.io.image.write_multichannel_np(prediction_local, self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_prediction_local.mha'), output_normalization_mode=heatmap_normalization_mode, sitk_image_output_mode='vector', data_format=self.data_format, image_type=image_type, spacing=self.image_spacing, origin=origin)
                utils.io.image.write_multichannel_np(prediction_spatial, self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_prediction_spatial.mha'), output_normalization_mode=heatmap_normalization_mode, sitk_image_output_mode='vector', data_format=self.data_format, image_type=image_type, spacing=self.image_spacing, origin=origin)

            # candidate landmarks per heatmap channel (multiple maxima are returned)
            local_maxima_landmarks = heatmap_maxima.get_landmarks(prediction, input_image, self.image_spacing, transformation)

            # landmarks without postprocessing are the first local maxima (with the largest value);
            # channels without any maximum get an invalid landmark
            curr_landmarks_no_postprocessing = [l[0] if len(l) > 0 else Landmark(coords=[np.nan] * 3, is_valid=False)  for l in local_maxima_landmarks]
            landmarks_no_postprocessing[current_id] = curr_landmarks_no_postprocessing

            if self.has_validation_groundtruth:
                landmark_statistics_no_postprocessing.add_landmarks(current_id, curr_landmarks_no_postprocessing, target_landmarks)
                vis.visualize_landmark_projections(input_image, target_landmarks, filename=self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_landmarks_gt.png'))
                vis.visualize_prediction_groundtruth_projections(input_image, curr_landmarks_no_postprocessing, target_landmarks, filename=self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_landmarks.png'))
            else:
                vis.visualize_landmark_projections(input_image, curr_landmarks_no_postprocessing, filename=self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_landmarks.png'))

            if self.evaluate_landmarks_postprocessing:
                # best effort: if any postprocessing step fails, fall back to the landmarks without postprocessing
                # NOTE(review): the exception itself is not reported, only the image id
                try:
                    local_maxima_landmarks = add_landmarks_from_neighbors(local_maxima_landmarks)
                    curr_landmarks = spine_postprocessing.solve_local_heatmap_maxima(local_maxima_landmarks)
                    curr_landmarks = reshift_landmarks(curr_landmarks)
                    curr_landmarks = filter_landmarks_top_bottom(curr_landmarks, input_image)
                except Exception:
                    print('error in postprocessing', current_id)
                    curr_landmarks = curr_landmarks_no_postprocessing
                landmarks[current_id] = curr_landmarks

                if self.has_validation_groundtruth:
                    landmark_statistics.add_landmarks(current_id, curr_landmarks, target_landmarks)
                    vis.visualize_prediction_groundtruth_projections(input_image, curr_landmarks, target_landmarks, filename=self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_landmarks_pp.png'))
                else:
                    vis.visualize_landmark_projections(input_image, curr_landmarks, filename=self.output_folder_handler.path_for_iteration(self.current_iter, current_id + '_landmarks_pp.png'))

        # landmarks is empty if self.evaluate_landmarks_postprocessing is False; points.csv is written nevertheless
        utils.io.landmark.save_points_csv(landmarks, self.output_folder_handler.path_for_iteration(self.current_iter, 'points.csv'))
        utils.io.landmark.save_points_csv(landmarks_no_postprocessing, self.output_folder_handler.path_for_iteration(self.current_iter, 'points_no_postprocessing.csv'))

        # print and save the landmark statistics, and collect the summary values for the logger
        if self.has_validation_groundtruth:
            summary_values = OrderedDict()
            if self.evaluate_landmarks_postprocessing:
                print(landmark_statistics.get_pe_overview_string())
                print(landmark_statistics.get_correct_id_string(20.0))
                overview_string = landmark_statistics.get_overview_string([2, 2.5, 3, 4, 10, 20], 10, 20.0)
                utils.io.text.save_string_txt(overview_string, self.output_folder_handler.path_for_iteration(self.current_iter, 'eval.txt'))
                summary_values.update(OrderedDict(zip(['pe_mean', 'pe_stdev', 'pe_median', 'num_correct'], list(landmark_statistics.get_pe_statistics()) + [landmark_statistics.get_num_correct_id(20)])))
            # the same statistics for the landmarks without postprocessing (suffix _np)
            print(landmark_statistics_no_postprocessing.get_pe_overview_string())
            print(landmark_statistics_no_postprocessing.get_correct_id_string(20.0))
            overview_string = landmark_statistics_no_postprocessing.get_overview_string([2, 2.5, 3, 4, 10, 20], 10, 20.0)
            utils.io.text.save_string_txt(overview_string, self.output_folder_handler.path_for_iteration(self.current_iter, 'eval_no_postprocessing.txt'))
            summary_values.update(OrderedDict(zip(['pe_mean_np', 'pe_stdev_np', 'pe_median_np', 'num_correct_np'], list(landmark_statistics_no_postprocessing.get_pe_statistics()) + [landmark_statistics_no_postprocessing.get_num_correct_id(20)])))
            self.loss_metric_logger_val.update_metrics(summary_values)

        # finalize loss values
        self.loss_metric_logger_val.finalize(self.current_iter)
# Example #5
# 0
class MainLoop(MainLoopBase):
    def __init__(self, config):
        """
        Initializer. Sets up network, folder and landmark label configuration for inference.
        :param config: Configuration object. The attributes read here are cv, num_filters_base, activation,
                       spatial_downsample, local_activation, spatial_activation, num_levels, model,
                       image_folder, setup_folder, output_folder, load_model_filenames and spacing.
        """
        super().__init__()
        # mixed precision is used exactly when TensorFlow reports a GPU device name
        self.use_mixed_precision = tf.test.gpu_device_name() != ''
        if self.use_mixed_precision:
            mixed_precision.set_policy(mixed_precision.Policy('mixed_float16'))

        self.cv = config.cv
        self.config = config
        self.batch_size = 1
        self.num_landmarks = 26
        self.data_format = 'channels_last'

        # keyword arguments forwarded to the network constructor in init_model()
        self.network_parameters = OrderedDict([
            ('num_filters_base', config.num_filters_base),
            ('activation', config.activation),
            ('spatial_downsample', config.spatial_downsample),
            ('local_activation', config.local_activation),
            ('spatial_activation', config.spatial_activation),
            ('num_levels', config.num_levels),
            ('data_format', self.data_format),
        ])
        # self.network stays unset for any other value of config.model
        if config.model == 'scn':
            self.network = SpatialConfigurationNet
        elif config.model == 'unet':
            self.network = Unet

        # behavior flags
        self.evaluate_landmarks_postprocessing = True
        self.use_pyro_dataset = True
        self.save_output_images = False
        self.save_debug_images = False

        # folders and model files
        self.image_folder = config.image_folder
        self.setup_folder = config.setup_folder
        self.output_folder = config.output_folder
        self.load_model_filenames = config.load_model_filenames

        # image geometry; the image size is left unspecified, tiles are limited in test_cropped_image()
        self.image_size = [None, None, None]
        self.image_spacing = [config.spacing] * 3
        self.max_image_size_for_cropped_test = [128, 128, 448]
        self.cropped_inc = [0, 128, 0, 0]
        self.heatmap_size = self.image_size

        # image ids are the file names of all '*.nii.gz' files in the image folder without that suffix
        suffix = '.nii.gz'
        image_files = sorted(glob(os.path.join(self.image_folder, '*' + suffix)))
        self.image_id_list = [os.path.basename(image_file)[:-len(suffix)]
                              for image_file in image_files]

        # landmark index (0..25) <-> label (1..25 and 28)
        self.landmark_labels = list(range(1, 26)) + [28]
        self.landmark_mapping = {index: self.landmark_labels[index]
                                 for index in range(26)}
        self.landmark_mapping_inverse = {self.landmark_labels[index]: index
                                         for index in range(26)}

        # NOTE: call_model is not wrapped in tf.function with a fixed input signature here.

    def init_model(self):
        """
        Init self.model by instantiating self.network with one output label per landmark.
        """
        network_class = self.network
        self.model = network_class(num_labels=self.num_landmarks,
                                   **self.network_parameters)

    def init_checkpoint(self):
        """
        Init self.checkpoint, tracking only self.model.
        """
        tracked_objects = {'model': self.model}
        self.checkpoint = tf.train.Checkpoint(**tracked_objects)

    def init_output_folder_handler(self):
        """
        Init self.output_folder_handler for self.output_folder, without timestamp and without files to copy.
        """
        handler_options = {'use_timestamp': False, 'files_to_copy': []}
        self.output_folder_handler = OutputFolderHandler(self.output_folder,
                                                         **handler_options)

    def init_datasets(self):
        """
        Init self.dataset_val and self.network_image_size.
        """
        dataset_parameters = {
            'image_base_folder': self.image_folder,
            'setup_base_folder': self.setup_folder,
            'image_size': self.image_size,
            'image_spacing': self.image_spacing,
            'num_landmarks': self.num_landmarks,
            'normalize_zero_mean_unit_variance': False,
            'cv': self.cv,
            'input_gaussian_sigma': 0.75,
            'crop_image_top_bottom': True,
            'use_variable_image_size': True,
            'load_spine_bbs': True,
            'valid_output_sizes_x': [64, 96],
            'valid_output_sizes_y': [64, 96],
            'valid_output_sizes_z': [64, 96, 128, 160, 192, 224, 256, 288, 320],
            'translate_to_center_landmarks': True,
            'translate_by_random_factor': True,
            'data_format': self.data_format,
            'save_debug_images': self.save_debug_images,
        }

        self.dataset_val = Dataset(**dataset_parameters).dataset_val()
        # the network image size is self.image_size in reversed axis order
        self.network_image_size = self.image_size[::-1]

    def call_model(self, image):
        """
        Call self.model on the given image with training disabled.
        :param image: The network input.
        :return: Whatever self.model returns for the image when called with training=False.
        """
        model_output = self.model(image, training=False)
        return model_output

    def convert_landmarks_to_verse_indexing(self, landmarks, image):
        """
        Return copies of the landmarks with coordinates converted to the indexing used for the VerSe json output.
        A valid landmark (x, y, z) becomes (y, extent_z - z, extent_x - x), where extent is size * spacing of the
        given image per axis. Invalid landmarks are copied unchanged. The input landmarks are not modified.
        :param landmarks: List of landmarks with attributes is_valid and coords (np.array).
        :param image: Image object providing GetSpacing() and GetSize().
        :return: List of converted landmark copies, in the same order as the input.
        """
        spacing = np.array(image.GetSpacing())
        size = np.array(image.GetSize())

        def to_verse(landmark):
            converted = deepcopy(landmark)
            if landmark.is_valid:
                source = np.array(landmark.coords.tolist())
                converted.coords = np.array([
                    source[1],
                    size[2] * spacing[2] - source[2],
                    size[0] * spacing[0] - source[0],
                ])
            return converted

        return [to_verse(landmark) for landmark in landmarks]

    def save_landmarks_verse_json(self, landmarks, filename):
        """
        Write all valid landmarks as a json list of {'label', 'X', 'Y', 'Z'} entries.
        The label of a landmark is looked up in self.landmark_mapping by its position in the list.
        :param landmarks: List of landmarks with attributes is_valid and coords.
        :param filename: The json file to write.
        """
        entries = [{
            'label': self.landmark_mapping[index],
            'X': landmark.coords[0],
            'Y': landmark.coords[1],
            'Z': landmark.coords[2]
        } for index, landmark in enumerate(landmarks) if landmark.is_valid]
        with open(filename, 'w') as json_file:
            json.dump(entries, json_file)

    def save_valid_landmarks_list(self, landmarks_dict, filename):
        """
        Save, per image id, the list indices of all valid landmarks as a csv file.
        :param landmarks_dict: Dict of image id -> list of landmarks with attribute is_valid.
        :param filename: The csv file to write.
        """
        valid_landmarks = {
            image_id: [
                landmark_id
                for landmark_id, landmark in enumerate(image_landmarks)
                if landmark.is_valid
            ]
            for image_id, image_landmarks in landmarks_dict.items()
        }
        utils.io.text.save_dict_csv(valid_landmarks, filename)

    def test_cropped_image(self, dataset_entry):
        """
        Perform inference on a dataset_entry. The image is processed in tiles of at most
        self.max_image_size_for_cropped_test (reversed to the array axis order), every model in
        self.load_model_filenames is evaluated on each tile and the model outputs are averaged per tile.
        The tile outputs are merged by the ImageTiler objects (as maxima, according to the original description
        of this function - TODO confirm against ImageTiler).
        :param dataset_entry: A dataset entry from the dataset; 'generators' and 'transformations' are read.
        :return: Tuple of (image, prediction, prediction_local, prediction_spatial, transformation). The first four
                 are the output_image attributes of the respective tilers, transformation is
                 dataset_entry['transformations']['image'].
        """
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        transformation = transformations['image']

        full_image = generators['image']

        # The tile size is the elementwise minimum of the spatial image size and the maximum cropped size.
        # Both branches create the same four tilers and only differ in the position of the channel axis.
        # NOTE(review): the meaning of the ImageTiler arguments (self.cropped_inc, True, -1 / -np.inf) is not
        # visible here - presumably tile increment, a flag and the default value; confirm against ImageTiler.
        if self.data_format == 'channels_first':
            image_size_for_tilers = np.minimum(
                full_image.shape[1:],
                list(reversed(self.max_image_size_for_cropped_test))).tolist()
            image_size_np = [1] + image_size_for_tilers
            labels_size_np = [self.num_landmarks] + image_size_for_tilers
            image_tiler = ImageTiler(full_image.shape, image_size_np,
                                     self.cropped_inc, True, -1)
            prediction_tiler = ImageTiler(
                (self.num_landmarks, ) + full_image.shape[1:], labels_size_np,
                self.cropped_inc, True, -np.inf)
            prediction_local_tiler = ImageTiler(
                (self.num_landmarks, ) + full_image.shape[1:], labels_size_np,
                self.cropped_inc, True, -np.inf)
            prediction_spatial_tiler = ImageTiler(
                (self.num_landmarks, ) + full_image.shape[1:], labels_size_np,
                self.cropped_inc, True, -np.inf)
        else:
            image_size_for_tilers = np.minimum(
                full_image.shape[:-1],
                list(reversed(self.max_image_size_for_cropped_test))).tolist()
            image_size_np = image_size_for_tilers + [1]
            labels_size_np = image_size_for_tilers + [self.num_landmarks]
            image_tiler = ImageTiler(full_image.shape, image_size_np,
                                     self.cropped_inc, True, -1)
            prediction_tiler = ImageTiler(
                full_image.shape[:-1] + (self.num_landmarks, ), labels_size_np,
                self.cropped_inc, True, -np.inf)
            prediction_local_tiler = ImageTiler(
                full_image.shape[:-1] + (self.num_landmarks, ), labels_size_np,
                self.cropped_inc, True, -np.inf)
            prediction_spatial_tiler = ImageTiler(
                full_image.shape[:-1] + (self.num_landmarks, ), labels_size_np,
                self.cropped_inc, True, -np.inf)

        # NOTE(review): the loop variables rebind the tiler names to whatever iterating the tilers yields.
        # The return statement below therefore relies on the yielded objects providing output_image -
        # presumably an ImageTiler yields itself on iteration; confirm before renaming the loop variables.
        for image_tiler, prediction_tiler, prediction_local_tiler, prediction_spatial_tiler in zip(
                image_tiler, prediction_tiler, prediction_local_tiler,
                prediction_spatial_tiler):
            current_image = image_tiler.get_current_data(full_image)
            predictions = []
            predictions_local = []
            predictions_spatial = []
            for load_model_filename in self.load_model_filenames:
                # with a single model file, the weights are loaded once in test() and not reloaded per tile
                if len(self.load_model_filenames) > 1:
                    self.load_model(load_model_filename)
                # a batch axis is added in front of the tile before calling the model
                prediction, prediction_local, prediction_spatial = self.call_model(
                    np.expand_dims(current_image, axis=0))
                predictions.append(prediction.numpy())
                predictions_local.append(prediction_local.numpy())
                predictions_spatial.append(prediction_spatial.numpy())
            # average the outputs of all models for the current tile
            prediction = np.mean(predictions, axis=0)
            prediction_local = np.mean(predictions_local, axis=0)
            prediction_spatial = np.mean(predictions_spatial, axis=0)
            image_tiler.set_current_data(current_image)
            # the batch axis is removed again before handing the tile outputs to the tilers
            prediction_tiler.set_current_data(np.squeeze(prediction, axis=0))
            prediction_local_tiler.set_current_data(
                np.squeeze(prediction_local, axis=0))
            prediction_spatial_tiler.set_current_data(
                np.squeeze(prediction_spatial, axis=0))

        return image_tiler.output_image, prediction_tiler.output_image, prediction_local_tiler.output_image, prediction_spatial_tiler.output_image, transformation

    def reshift_landmarks(self, curr_landmarks):
        """
        Shift the landmark indices of a spine region by one, if the pattern of valid landmarks indicates that the
        region was labeled with an offset. Indices follow the annotation used in test(): 0-6 are C1-C7,
        7-18 are T1-T12, 19-24 are L1-L6 and 25 is T13.

        Cervical shift up: C1 and C7 are missing while C6 and T1 are present. The landmarks C1..C6 are moved to
        C2..C7 and the invalid C7 entry is dropped. The previous slicing (0:5 + 6:26) dropped the valid C6 landmark
        and kept the invalid C7 entry instead, which left C7 empty after the shift; it now mirrors the thoracic
        shift, which drops the invalid entry.

        Thoracic shift up: T1 and T12 are missing while T11 and L1 are present. T1..T11 are moved to T2..T12.
        Thoracic shift down: T1 is missing, L1 and T13 are present (and the shift up does not apply). T2..T12 are
        moved to T1..T11, T13 becomes T12 and T13 is set to invalid.

        :param curr_landmarks: List of 26 landmarks with attribute is_valid.
        :return: The same list object, if no shift applies; otherwise a new list of 26 landmarks.
        """
        if (not curr_landmarks[0].is_valid) and curr_landmarks[7].is_valid:
            if (not curr_landmarks[6].is_valid) and curr_landmarks[5].is_valid:
                # shift c indices up: prepend an invalid C1, keep old C1..C6 (0:6), drop the invalid C7 (6)
                print('shift c indizes up')
                curr_landmarks = [
                    Landmark([np.nan] * 3, is_valid=False)
                ] + curr_landmarks[0:6] + curr_landmarks[7:26]
        if (not curr_landmarks[7].is_valid) and curr_landmarks[19].is_valid:
            if (not curr_landmarks[18].is_valid
                ) and curr_landmarks[17].is_valid:
                # shift t indices up: insert an invalid T1, keep old T1..T11 (7:18), drop the invalid T12 (18)
                print('shift t indizes up')
                curr_landmarks = curr_landmarks[0:7] + [
                    Landmark([np.nan] * 3, is_valid=False)
                ] + curr_landmarks[7:18] + curr_landmarks[19:26]
            elif curr_landmarks[25].is_valid:
                # shift t indices down: drop the invalid T1 (7), T13 (25) becomes T12, T13 is set to invalid
                print('shift t indizes down')
                curr_landmarks = curr_landmarks[0:7] + curr_landmarks[8:19] + [
                    curr_landmarks[25]
                ] + curr_landmarks[19:25] + [
                    Landmark([np.nan] * 3, is_valid=False)
                ]
        return curr_landmarks

    def filter_landmarks_top_bottom(self, curr_landmarks, input_image):
        """
        Replace every landmark whose z coordinate is not strictly more than 10 away from both z ends of the image
        extent (size * spacing) with an invalid landmark. Landmarks inside that range are kept as they are.
        :param curr_landmarks: List of landmarks with attribute coords.
        :param input_image: Image object providing GetSpacing() and GetSize().
        :return: New list with the same length and order as curr_landmarks.
        """
        image_extent = [
            axis_spacing * axis_size
            for axis_spacing, axis_size in zip(input_image.GetSpacing(),
                                               input_image.GetSize())
        ]
        margin = 10
        return [
            landmark
            if margin < landmark.coords[2] < image_extent[2] - margin
            else Landmark(coords=[np.nan] * 3, is_valid=False)
            for landmark in curr_landmarks
        ]

    def add_landmarks_from_neighbors(self, local_maxima_landmarks):
        """
        Add the candidate landmarks of neighboring landmark indices as additional candidates.
        The exchange is done for the consecutive index pairs (2, 3) to (5, 6), (8, 9) to (17, 18) and (20, 21) to
        (23, 24), where the value of a copied candidate is multiplied by 0.1, and for the pair (18, 25), where the
        value is copied unchanged. The pairs are processed in this order on the same lists, so candidates that were
        added by an earlier step are copied again by later steps.
        :param local_maxima_landmarks: List of candidate lists (one list of landmarks per landmark index).
        :return: Deep copy of local_maxima_landmarks with the additional candidates appended.
        """
        candidates = deepcopy(local_maxima_landmarks)
        duplicate_penalty = 0.1

        def exchange(first, second, penalty):
            # order matters: second is extended before first, so first also receives the copies just added
            for target, source in ((second, first), (first, second)):
                if penalty is None:
                    copies = [
                        Landmark(coords=candidate.coords, value=candidate.value)
                        for candidate in candidates[source]
                    ]
                else:
                    copies = [
                        Landmark(coords=candidate.coords,
                                 value=candidate.value * penalty)
                        for candidate in candidates[source]
                    ]
                candidates[target].extend(copies)

        for index in range(2, 6):
            exchange(index, index + 1, duplicate_penalty)
        for index in range(8, 18):
            exchange(index, index + 1, duplicate_penalty)
        # indices 18 and 25 exchange their candidates without penalty
        exchange(18, 25, None)
        for index in range(20, 24):
            exchange(index, index + 1, duplicate_penalty)
        return candidates

    def test(self):
        """
        The test function. Performs inference on all images in self.image_id_list, extracts the landmarks from the
        predicted heatmaps with and without postprocessing and writes the results to the output folder:
        one '<image_id>_ctd.json' per image, 'landmarks.csv', 'landmarks_no_postprocessing.csv' and
        'valid_landmarks.csv'. If self.save_output_images is set, images and landmark visualizations are written
        as well. No loss or evaluation metric is calculated here.
        An error while processing an image is reported with its traceback and the loop continues with the next
        image; KeyboardInterrupt and SystemExit are not caught, so the loop can be aborted.
        """
        # local import, as traceback is only needed for reporting per-image failures
        import traceback

        print('Testing...')

        # with several model files, the models are loaded per tile in test_cropped_image() instead
        if len(self.load_model_filenames) == 1:
            self.load_model(self.load_model_filenames[0])

        # landmark index -> displayed name: C1-C7, T1-T12, L1-L6 and T13
        vis = LandmarkVisualizationMatplotlib(
            annotations=dict([(i, f'C{i + 1}') for i in range(7)] +
                             [(i, f'T{i - 6}') for i in range(7, 19)] +
                             [(i, f'L{i - 18}') for i in range(19, 25)] +
                             [(25, 'T13')]))

        channel_axis = 3 if self.data_format == 'channels_last' else 0
        heatmap_maxima = HeatmapTest(channel_axis,
                                     False,
                                     return_multiple_maxima=True,
                                     min_max_value=0.05,
                                     smoothing_sigma=2.0)

        # NOTE: both pickle files are read from the current working directory. pickle.load executes arbitrary
        # code on malicious input, so only files from a trusted source must be used here.
        with open('possible_successors.pickle', 'rb') as f:
            possible_successors = pickle.load(f)
        with open('units_distances.pickle', 'rb') as f:
            offsets_mean, distances_mean, distances_std = pickle.load(f)
        spine_postprocessing = SpinePostprocessingGraph(
            num_landmarks=self.num_landmarks,
            possible_successors=possible_successors,
            offsets_mean=offsets_mean,
            distances_mean=distances_mean,
            distances_std=distances_std,
            bias=2.0,
            l=0.2)

        landmarks = {}
        landmarks_no_postprocessing = {}
        for current_id in tqdm(self.image_id_list, desc='Testing'):
            try:
                dataset_entry = self.dataset_val.get({'image_id': current_id})
                print(current_id)
                datasources = dataset_entry['datasources']
                input_image = datasources['image']

                image, prediction, prediction_local, prediction_spatial, transformation = self.test_cropped_image(
                    dataset_entry)

                origin = transformation.TransformPoint(np.zeros(3, np.float64))
                if self.save_output_images:
                    heatmap_normalization_mode = (-1, 1)
                    # arguments shared by all written images
                    common_arguments = dict(sitk_image_output_mode='vector',
                                            data_format=self.data_format,
                                            image_type=np.uint8,
                                            spacing=self.image_spacing,
                                            origin=origin)
                    # (data, file name suffix, arguments specific to this image)
                    outputs = [
                        (image, '_input.mha',
                         dict(output_normalization_mode='min_max')),
                        (prediction, '_prediction.mha',
                         dict(output_normalization_mode=heatmap_normalization_mode)),
                        (prediction, '_prediction_rgb.mha',
                         dict(output_normalization_mode=(0, 1),
                              channel_layout_mode='channel_rgb')),
                        (prediction_local, '_prediction_local.mha',
                         dict(output_normalization_mode=heatmap_normalization_mode)),
                        (prediction_spatial, '_prediction_spatial.mha',
                         dict(output_normalization_mode=heatmap_normalization_mode)),
                    ]
                    for data, suffix, specific_arguments in outputs:
                        utils.io.image.write_multichannel_np(
                            data,
                            self.output_folder_handler.path(
                                'output', current_id + suffix),
                            **specific_arguments,
                            **common_arguments)

                local_maxima_landmarks = heatmap_maxima.get_landmarks(
                    prediction, input_image, self.image_spacing,
                    transformation)

                # without postprocessing, the first returned candidate of each landmark is used
                curr_landmarks_no_postprocessing = [
                    l[0] for l in local_maxima_landmarks
                ]
                landmarks_no_postprocessing[
                    current_id] = curr_landmarks_no_postprocessing

                # postprocessing: neighbor candidates -> graph -> index reshift -> top/bottom filter
                local_maxima_landmarks = self.add_landmarks_from_neighbors(
                    local_maxima_landmarks)
                curr_landmarks = spine_postprocessing.solve_local_heatmap_maxima(
                    local_maxima_landmarks)
                curr_landmarks = self.reshift_landmarks(curr_landmarks)
                curr_landmarks = self.filter_landmarks_top_bottom(
                    curr_landmarks, input_image)
                landmarks[current_id] = curr_landmarks

                if self.save_output_images:
                    vis.visualize_projections(
                        input_image,
                        curr_landmarks_no_postprocessing,
                        None,
                        filename=self.output_folder_handler.path(
                            'output', current_id + '_landmarks.png'))
                    vis.visualize_projections(
                        input_image,
                        curr_landmarks,
                        None,
                        filename=self.output_folder_handler.path(
                            'output', current_id + '_landmarks_pp.png'))

                verse_landmarks = self.convert_landmarks_to_verse_indexing(
                    curr_landmarks, input_image)
                self.save_landmarks_verse_json(
                    verse_landmarks,
                    self.output_folder_handler.path(current_id + '_ctd.json'))
            except Exception:
                # best effort: report the failure with its cause and continue with the next image
                print('ERROR predicting', current_id)
                traceback.print_exc()

        utils.io.landmark.save_points_csv(
            landmarks, self.output_folder_handler.path('landmarks.csv'))
        utils.io.landmark.save_points_csv(
            landmarks_no_postprocessing,
            self.output_folder_handler.path('landmarks_no_postprocessing.csv'))
        self.save_valid_landmarks_list(
            landmarks, self.output_folder_handler.path('valid_landmarks.csv'))
Пример #6
0
class MainLoop(MainLoopBase):
    def __init__(self, cv, config):
        """
        Initializer. Sets up training schedule, network, folder and landmark label configuration.
        :param cv: The cv fold. 0, 1, 2 for CV; for any other value (e.g., 'train_all'), 'train_all.txt' is used
                   as the test file and no validation groundtruth is assumed.
        :param config: Configuration object. The attributes read here are learning_rate, num_filters_base,
                       activation, dropout_ratio, num_levels, model, spacing and info.
        """
        super().__init__()
        self.use_mixed_precision = True
        if self.use_mixed_precision:
            mixed_precision.set_policy(mixed_precision.Policy('mixed_float16'))
        self.cv = cv
        self.config = config
        self.batch_size = 1
        self.learning_rate = config.learning_rate

        # iteration schedule
        self.max_iter = 50000
        self.test_iter = 5000
        self.disp_iter = 100
        self.snapshot_iter = 5000
        self.test_initialization = False

        self.reg_constant = 0.0
        self.use_background = True
        self.num_labels = 1
        self.num_labels_all = 27
        self.data_format = 'channels_first'

        # keyword arguments forwarded to the network constructor in init_model()
        self.network_parameters = OrderedDict([
            ('num_filters_base', config.num_filters_base),
            ('activation', config.activation),
            ('dropout_ratio', config.dropout_ratio),
            ('num_levels', config.num_levels),
            ('data_format', self.data_format),
        ])
        # self.network stays unset for any other value of config.model
        if config.model == 'unet':
            self.network = Unet
        self.clip_gradient_global_norm = 100.0

        is_cv_fold = cv in [0, 1, 2]
        self.use_pyro_dataset = True
        self.save_output_images = True
        self.save_debug_images = False
        self.has_validation_groundtruth = is_cv_fold
        self.local_base_folder = '../verse2020_dataset'
        self.image_size = [128, 128, 96]
        self.image_spacing = [config.spacing] * 3
        self.heatmap_size = self.image_size
        self.base_output_folder = './output/vertebrae_segmentation/'
        self.additional_output_folder_info = config.info

        # wrap call_model_and_loss into a tf.function with a fixed input signature:
        # two channel network input, single channel uint8 labels and the boolean training flag
        network_float_type = tf.float16 if self.use_mixed_precision else tf.float32
        spatial_shape = list(reversed(self.image_size))
        if self.data_format == 'channels_first':
            input_shape = tf.TensorShape([1, 2] + spatial_shape)
            label_shape = tf.TensorShape([1, 1] + spatial_shape)
        else:
            input_shape = tf.TensorShape([1] + spatial_shape) + [2]
            label_shape = tf.TensorShape([1] + spatial_shape) + [1]
        self.call_model_and_loss = tf.function(
            self.call_model_and_loss,
            input_signature=[
                tf.TensorSpec(input_shape, network_float_type),
                tf.TensorSpec(label_shape, tf.uint8),
                tf.TensorSpec(tf.TensorShape(None), tf.bool)
            ])

        self.dice_names = ['mean_dice'] + [
            'dice_{}'.format(label) for label in range(1, self.num_labels_all)
        ]

        # setup files: id list of the images to test and the valid landmarks per image
        self.setup_base_folder = os.path.join(self.local_base_folder, 'setup')
        if is_cv_fold:
            self.cv_folder = os.path.join(self.setup_base_folder, 'cv', str(cv))
            self.test_file = os.path.join(self.cv_folder, 'val.txt')
        else:
            self.test_file = os.path.join(self.setup_base_folder,
                                          'train_all.txt')
        self.valid_landmarks_file = os.path.join(self.setup_base_folder,
                                                 'valid_landmarks.csv')
        self.test_id_list = utils.io.text.load_list(self.test_file)
        self.valid_landmarks = utils.io.text.load_dict_csv(
            self.valid_landmarks_file, squeeze=False)

        # landmark index (0..25) <-> label (1..25 and 28)
        self.landmark_labels = list(range(1, 26)) + [28]
        self.landmark_mapping = {index: self.landmark_labels[index]
                                 for index in range(26)}
        self.landmark_mapping_inverse = {self.landmark_labels[index]: index
                                         for index in range(26)}

    def init_model(self):
        """
        Init self.model and self.norm_moving_average.
        """
        # variable initialized with 1.0; train_step() derives the gradient clipping norm from it
        self.norm_moving_average = tf.Variable(1.0)
        network_class = self.network
        self.model = network_class(num_labels=self.num_labels,
                                   **self.network_parameters)

    def init_optimizer(self):
        """
        Init self.optimizer. self.learning_rate is replaced by an exponential decay schedule with decay rate 0.1
        over self.max_iter steps. With mixed precision, the optimizer is wrapped with dynamic loss scaling.
        """
        decay_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            self.learning_rate, self.max_iter, 0.1)
        self.learning_rate = decay_schedule
        optimizer = tf.keras.optimizers.Adam(learning_rate=decay_schedule,
                                             amsgrad=True)
        if self.use_mixed_precision:
            dynamic_loss_scale = tf.mixed_precision.experimental.DynamicLossScale(
                initial_loss_scale=2**15, increment_period=1000)
            optimizer = mixed_precision.LossScaleOptimizer(
                optimizer, loss_scale=dynamic_loss_scale)
        self.optimizer = optimizer

    def init_checkpoint(self):
        """
        Init self.checkpoint, tracking self.model and self.optimizer.
        """
        tracked_objects = {'model': self.model, 'optimizer': self.optimizer}
        self.checkpoint = tf.train.Checkpoint(**tracked_objects)

    def init_output_folder_handler(self):
        """
        Init self.output_folder_handler from the base output folder, the model name, the cv fold and the
        additional output folder info.
        """
        handler_options = {
            'model_name': self.model.name,
            'cv': str(self.cv),
            'additional_info': self.additional_output_folder_info,
        }
        self.output_folder_handler = OutputFolderHandler(
            self.base_output_folder, **handler_options)

    def init_datasets(self):
        """
        Init self.dataset_train, self.dataset_train_iter, self.dataset_val and self.network_image_size.
        If self.use_pyro_dataset is set, the training dataset is a PyroClientDataset connecting to a dataset
        server on this host; otherwise it is created locally.
        """
        network_np_type = np.float16 if self.use_mixed_precision else np.float32
        network_tf_type = tf.float16 if self.use_mixed_precision else tf.float32
        dataset_parameters = {
            'base_folder': self.local_base_folder,
            'image_size': self.image_size,
            'image_spacing': self.image_spacing,
            'normalize_zero_mean_unit_variance': False,
            'cv': self.cv,
            'label_gaussian_sigma': 1.0,
            'random_translation': 10.0,
            'random_rotate': 0.5,
            'heatmap_sigma': 3.0,
            'generate_single_vertebrae_heatmap': True,
            'generate_single_vertebrae': True,
            'output_image_type': network_np_type,
            'data_format': self.data_format,
            'save_debug_images': self.save_debug_images,
        }

        dataset = Dataset(**dataset_parameters)
        if not self.use_pyro_dataset:
            self.dataset_train = dataset.dataset_train()
        else:
            # TODO: adapt hostname, in case this script runs on a remote server
            uri = 'PYRO:verse2020_dataset' + '@' + socket.gethostname() + ':52132'
            print('using pyro uri', uri)
            try:
                self.dataset_train = PyroClientDataset(uri,
                                                       **dataset_parameters)
            except Exception:
                print(
                    'Error connecting to server dataset. Start server_dataset_loop.py and set correct hostname, or set self.use_pyro_dataset = False.'
                )
                raise

        self.dataset_val = dataset.dataset_val()
        # the network image size is self.image_size in reversed axis order
        self.network_image_size = list(reversed(self.image_size))

        channels_first = self.data_format == 'channels_first'

        def shape_with_channels(num_channels):
            # channel axis in front for 'channels_first', at the end otherwise
            if channels_first:
                return [num_channels] + self.network_image_size
            return self.network_image_size + [num_channels]

        data_generator_entries = OrderedDict([
            ('image', shape_with_channels(1)),
            ('single_label', shape_with_channels(self.num_labels)),
            ('single_heatmap', shape_with_channels(1))
        ])
        data_generator_types = {
            'image': network_tf_type,
            'single_heatmap': network_tf_type,
            'single_label': tf.uint8
        }
        self.dataset_train_iter = DatasetIterator(
            dataset=self.dataset_train,
            data_names_and_shapes=data_generator_entries,
            data_types=data_generator_types,
            batch_size=self.batch_size,
            n_threads=4)

    def init_loggers(self):
        """
        Init self.loss_metric_logger_train, self.loss_metric_logger_val.
        Both loggers get a folder and a csv file path from the output folder handler.
        """
        output_path = self.output_folder_handler.path
        self.loss_metric_logger_train = LossMetricLogger(
            'train', output_path('train'), output_path('train.csv'))
        self.loss_metric_logger_val = LossMetricLogger(
            'test', output_path('test'), output_path('test.csv'))

    @tf.function
    def loss_function(self, pred, target):
        """
        Sigmoid cross entropy loss calculated with prediction and target.
        Both inputs are cast to float32 before the loss calculation.
        :param pred: The predicted image, used as logits.
        :param target: The target image, used as labels.
        :return: Mean over all elements of the sigmoid cross entropy between target and pred.
        """
        labels = tf.cast(target, tf.float32)
        logits = tf.cast(pred, tf.float32)
        elementwise_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels, logits=logits)
        return tf.reduce_mean(elementwise_loss)

    def call_model_and_loss(self, image, target_labels, training):
        """
        Call the model and calculate the loss of its prediction.
        :param image: The image to call the model with.
        :param target_labels: The target labels used for loss calculation.
        :param training: training parameter used for calling the model.
        :return: (prediction, losses) tuple, where losses is a dict with the single entry 'loss_net'.
        """
        prediction = self.model(image, training=training)
        losses = {
            'loss_net': self.loss_function(target=target_labels,
                                           pred=prediction)
        }
        return prediction, losses

    @tf.function
    def train_step(self):
        """
        Perform a single training step on the next batch of self.dataset_train_iter.
        The image and the single heatmap are concatenated along the channel axis and fed to the model.
        Gradients are clipped by global norm against twice the running norm average, the running average
        is updated only for finite norms, and losses and norms are sent to self.loss_metric_logger_train.
        """
        image, single_label, single_heatmap = self.dataset_train_iter.get_next()
        concat_axis = 1 if self.data_format == 'channels_first' else -1
        network_input = tf.concat([image, single_heatmap], axis=concat_axis)
        with tf.GradientTape() as tape:
            _, metrics = self.call_model_and_loss(network_input,
                                                  single_label,
                                                  training=True)
            if self.reg_constant > 0:
                metrics['loss_reg'] = self.reg_constant * tf.reduce_sum(
                    self.model.losses)
            total_loss = tf.reduce_sum(list(metrics.values()))
            if self.use_mixed_precision:
                # gradients are taken w.r.t. the scaled loss and unscaled again below
                scaled_loss = self.optimizer.get_scaled_loss(total_loss)
        trainable = self.model.trainable_weights
        # clipping threshold is derived from the running average before it is updated
        max_norm = self.norm_moving_average * 2
        if self.use_mixed_precision:
            raw_grads = self.optimizer.get_unscaled_gradients(
                tape.gradient(scaled_loss, trainable))
            metrics['loss_scale'] = self.optimizer.loss_scale()
        else:
            raw_grads = tape.gradient(total_loss, trainable)
        clipped_grads, norm = tf.clip_by_global_norm(raw_grads, max_norm)
        alpha = 0.99
        if tf.math.is_finite(norm):
            # exponential moving average of the (capped) gradient norm
            self.norm_moving_average.assign((1 - alpha) *
                                            tf.minimum(norm, max_norm) +
                                            alpha * self.norm_moving_average)
        metrics['norm'] = norm
        metrics['norm_average'] = self.norm_moving_average
        self.optimizer.apply_gradients(zip(clipped_grads, trainable))
        self.loss_metric_logger_train.update_metrics(metrics)

    def test_full_image(self, dataset_entry):
        """
        Run the model on a single dataset_entry without training behaviour.
        If self.has_validation_groundtruth is set, the loss is computed as well and passed
        to self.loss_metric_logger_val.
        :param dataset_entry: A dataset entry from the dataset; its 'generators' and 'transformations' are used.
        :return: (image, prediction, transformation) - the generator image without batch axis, the prediction
                 with the batch axis squeezed out, and the 'image' entry of the transformations.
        """
        generators = dataset_entry['generators']
        transformations = dataset_entry['transformations']
        concat_axis = 1 if self.data_format == 'channels_first' else -1
        # add a batch axis of size 1 to both network inputs
        batched_image = np.expand_dims(generators['image'], axis=0)
        batched_heatmap = np.expand_dims(generators['single_heatmap'], axis=0)
        network_input = tf.concat([batched_image, batched_heatmap],
                                  axis=concat_axis)
        if not self.has_validation_groundtruth:
            batched_prediction = self.model(network_input, training=False)
        else:
            batched_label = np.expand_dims(generators['single_label'], axis=0)
            batched_prediction, losses = self.call_model_and_loss(
                network_input, batched_label, training=False)
            self.loss_metric_logger_val.update_metrics(losses)
        prediction = np.squeeze(batched_prediction, axis=0)
        return generators['image'], prediction, transformations['image']

    def test(self):
        """
        The test function. Performs inference on the validation images and calculates the loss.
        For every image in self.test_id_list, each valid landmark is predicted separately and the
        per-landmark predictions are merged into one label volume: a voxel receives the label of the
        landmark whose resampled prediction is the largest seen so far. The merged labels are written
        as '<image_id>.mha' into the folder of the current iteration. If groundtruth is available,
        dice scores are accumulated and logged through self.loss_metric_logger_val.
        """
        print('Testing...')
        # index of the channel axis of the (unbatched) network prediction
        channel_axis = 0
        if self.data_format == 'channels_last':
            channel_axis = 3
        # dice is evaluated for the labels 1 .. self.num_labels_all - 1
        segmentation_statistics = SegmentationStatistics(
            list(range(1, self.num_labels_all)),
            self.output_folder_handler.path_for_iteration(self.current_iter),
            metrics=OrderedDict([('dice', DiceMetric())]))
        filter_largest_cc = True

        # iterate over all images
        for image_id in tqdm(self.test_id_list, desc='Testing'):
            first = True
            prediction_labels_np = None
            prediction_max_value_np = None
            input_image = None
            groundtruth = None
            # iterate over all valid landmarks
            for landmark_id in self.valid_landmarks[image_id]:
                dataset_entry = self.dataset_val.get({
                    'image_id': image_id,
                    'landmark_id': landmark_id
                })
                if first:
                    # the accumulation volumes are allocated once per image, from the first entry
                    input_image = dataset_entry['datasources']['image']
                    if self.has_validation_groundtruth:
                        groundtruth = dataset_entry['datasources']['labels']
                    # volumes use the reversed sitk size (presumably numpy z, y, x index order)
                    prediction_labels_np = np.zeros(list(
                        reversed(input_image.GetSize())),
                                                    dtype=np.uint8)
                    # the running maximum starts at 0.5, so only predictions above 0.5 can claim a voxel
                    prediction_max_value_np = np.ones(list(
                        reversed(input_image.GetSize())),
                                                      dtype=np.float32) * 0.5
                    first = False

                image, prediction, transformation = self.test_full_image(
                    dataset_entry)
                del dataset_entry

                # transformed positions of the zero point and of the far corner (size * spacing) of the network volume
                origin = transformation.TransformPoint(np.zeros(3, np.float64))
                max_index = transformation.TransformPoint(
                    np.array(self.image_size, np.float64) *
                    np.array(self.image_spacing, np.float64))

                if self.save_output_images:
                    utils.io.image.write_multichannel_np(
                        image,
                        self.output_folder_handler.path(
                            'output',
                            image_id + '_' + landmark_id + '_input.mha'),
                        output_normalization_mode='min_max',
                        sitk_image_output_mode='vector',
                        data_format=self.data_format,
                        image_type=np.uint8,
                        spacing=self.image_spacing,
                        origin=origin)
                    utils.io.image.write_multichannel_np(
                        prediction,
                        self.output_folder_handler.path(
                            'output',
                            image_id + '_' + landmark_id + '_prediction.mha'),
                        output_normalization_mode=(0, 1),
                        sitk_image_output_mode='vector',
                        data_format=self.data_format,
                        image_type=np.uint8,
                        spacing=self.image_spacing,
                        origin=origin)
                del image
                # cast to float32 before resampling (the prediction may be float16 under mixed precision)
                prediction = prediction.astype(np.float32)
                # resample the network output back onto the grid of the input image
                prediction_resampled_sitk = utils.sitk_image.transform_np_output_to_sitk_input(
                    output_image=prediction,
                    output_spacing=self.image_spacing,
                    channel_axis=channel_axis,
                    input_image_sitk=input_image,
                    transform=transformation,
                    interpolator='cubic',
                    output_pixel_type=sitk.sitkFloat32)
                del prediction
                # only the first returned image is used (the single output channel, presumably)
                prediction_resampled_np = utils.sitk_np.sitk_to_np(
                    prediction_resampled_sitk[0])
                if self.save_output_images:
                    utils.io.image.write_multichannel_np(
                        prediction_resampled_np,
                        self.output_folder_handler.path(
                            'output', image_id + '_' + landmark_id +
                            '_prediction_resampled.mha'),
                        output_normalization_mode=(0, 1),
                        is_single_channel=True,
                        sitk_image_output_mode='vector',
                        data_format=self.data_format,
                        image_type=np.uint8,
                        spacing=prediction_resampled_sitk[0].GetSpacing(),
                        origin=prediction_resampled_sitk[0].GetOrigin())
                # bounding box of the network volume in voxel indices of the input image, flipped to numpy axis order
                # NOTE(review): positions are only divided by the spacing, no image origin/direction is applied -
                # presumably the input images have zero origin and identity direction; confirm.
                bb_start = np.floor(
                    np.flip(origin / np.array(input_image.GetSpacing())))
                bb_start = np.maximum(bb_start, [0, 0, 0])
                bb_end = np.ceil(
                    np.flip(max_index / np.array(input_image.GetSpacing())))
                bb_end = np.minimum(
                    bb_end, prediction_resampled_np.shape - np.ones(3)
                )  # bb is inclusive -> subtract [1, 1, 1] from max size
                slices = tuple([
                    slice(int(bb_start[i]), int(bb_end[i] + 1))
                    for i in range(3)
                ])
                # basic slicing: for numpy arrays the cropped arrays below are views, in-place writes reach the full volumes
                prediction_resampled_cropped_np = prediction_resampled_np[
                    slices]
                if filter_largest_cc:
                    # suppress all voxels above 0.5 that are not part of the largest connected component
                    prediction_thresh_cropped_np = (
                        prediction_resampled_cropped_np > 0.5).astype(np.uint8)
                    largest_connected_component = utils.np_image.largest_connected_component(
                        prediction_thresh_cropped_np)
                    # after this line the mask only contains the smaller components
                    prediction_thresh_cropped_np[largest_connected_component ==
                                                 1] = 0
                    prediction_resampled_cropped_np[
                        prediction_thresh_cropped_np == 1] = 0
                prediction_max_value_cropped_np = prediction_max_value_np[
                    slices]
                prediction_labels_cropped_np = prediction_labels_np[slices]
                # index 1 marks voxels where the current landmark prediction is selected over the running maximum
                # (tie handling depends on utils.np_image.argmax)
                prediction_max_index_np = utils.np_image.argmax(np.stack(
                    [
                        prediction_max_value_cropped_np,
                        prediction_resampled_cropped_np
                    ],
                    axis=-1),
                                                                axis=-1)
                prediction_max_index_new_np = prediction_max_index_np == 1
                # update the running maximum and assign the label of the current landmark at the winning voxels
                prediction_max_value_cropped_np[
                    prediction_max_index_new_np] = prediction_resampled_cropped_np[
                        prediction_max_index_new_np]
                prediction_labels_cropped_np[
                    prediction_max_index_new_np] = self.landmark_mapping[int(
                        landmark_id)]
                # write the cropped results back into the full volumes
                prediction_max_value_np[
                    slices] = prediction_max_value_cropped_np
                prediction_labels_np[slices] = prediction_labels_cropped_np
                del prediction_resampled_sitk

            # delete to save memory
            del prediction_max_value_np
            # convert the merged labels to sitk, copy the image information from the input image and save them
            prediction_labels = utils.sitk_np.np_to_sitk(prediction_labels_np)
            prediction_labels.CopyInformation(input_image)
            del prediction_labels_np
            utils.io.image.write(
                prediction_labels,
                self.output_folder_handler.path_for_iteration(
                    self.current_iter, image_id + '.mha'))
            if self.save_output_images:
                # debug outputs: labels and input resampled to [1.0, 1.0, 1.0] spacing and flipped along axis 0
                prediction_labels_resampled = utils.sitk_np.sitk_to_np(
                    utils.sitk_image.resample_to_spacing(
                        prediction_labels, [1.0, 1.0, 1.0], 'nearest'))
                prediction_labels_resampled = np.flip(
                    prediction_labels_resampled, axis=0)
                utils.io.image.write_multichannel_np(
                    prediction_labels_resampled,
                    self.output_folder_handler.path('output',
                                                    image_id + '_seg.png'),
                    channel_layout_mode='label_rgb',
                    output_normalization_mode=(0, 1),
                    image_layout_mode='max_projection',
                    is_single_channel=True,
                    sitk_image_output_mode='vector',
                    data_format=self.data_format,
                    image_type=np.uint8)
                utils.io.image.write_multichannel_np(
                    prediction_labels_resampled,
                    self.output_folder_handler.path('output',
                                                    image_id + '_seg_rgb.mha'),
                    channel_layout_mode='label_rgb',
                    output_normalization_mode=(0, 1),
                    is_single_channel=True,
                    sitk_image_output_mode='vector',
                    data_format=self.data_format,
                    image_type=np.uint8)
                input_resampled = utils.sitk_np.sitk_to_np(
                    utils.sitk_image.resample_to_spacing(
                        input_image, [1.0, 1.0, 1.0], 'linear'))
                input_resampled = np.flip(input_resampled, axis=0)
                utils.io.image.write_multichannel_np(
                    input_resampled,
                    self.output_folder_handler.path('output',
                                                    image_id + '_input.png'),
                    output_normalization_mode='min_max',
                    image_layout_mode='max_projection',
                    is_single_channel=True,
                    sitk_image_output_mode='vector',
                    data_format=self.data_format,
                    image_type=np.uint8)

            if self.has_validation_groundtruth:
                segmentation_statistics.add_labels(image_id, prediction_labels,
                                                   groundtruth)
            del prediction_labels

        # finalize loss values
        if self.has_validation_groundtruth:
            segmentation_statistics.finalize()
            dice_list = segmentation_statistics.get_metric_mean_list('dice')
            # the overall mean (ignoring NaN entries) is prepended to the per-label dice values
            mean_dice = np.nanmean(dice_list)
            dice_list = [mean_dice] + dice_list
            summary_values = OrderedDict(list(zip(self.dice_names, dice_list)))
            self.loss_metric_logger_val.update_metrics(summary_values)

        self.loss_metric_logger_val.finalize(self.current_iter)