Example #1
0
    def init_networks(self):
        """
        Build the inference-only (validation) graph.

        One ``'image'`` placeholder is created with a leading batch dimension of
        1. The network, wrapped with ``tf.make_template`` under the scope
        ``'net'``, is applied to it with ``is_training=False``. The network's
        three outputs are stored on ``self.prediction_val``,
        ``self.local_prediction_val`` and ``self.spatial_prediction_val``.
        """
        spatial_shape = list(reversed(self.image_size))

        # Single input channel: in front of the spatial dims for channels_first,
        # behind them otherwise.
        if self.data_format == 'channels_first':
            image_shape = [1] + spatial_shape
        else:
            image_shape = spatial_shape + [1]
        entries = OrderedDict([('image', image_shape)])
        types = {'image': tf.float32}

        # Template so that every call of `net` shares the same weights.
        net = tf.make_template('net', self.network)

        # NOTE(review): with a single entry, create_placeholders_tuple may return a
        # 1-tuple rather than a bare tensor -- confirm the network accepts that.
        self.data_val = create_placeholders_tuple(entries,
                                                  data_types=types,
                                                  shape_prefix=[1])
        outputs = net(self.data_val,
                      num_labels=self.num_labels,
                      is_training=False,
                      actual_network=self.unet,
                      padding=self.padding,
                      data_format=self.data_format,
                      **self.network_parameters)
        self.prediction_val, self.local_prediction_val, self.spatial_prediction_val = outputs
Example #2
0
    def init_networks(self):
        """
        Initialize networks and placeholders.

        Builds a shared-weight network (``tf.make_template`` scope ``'net'``) and
        wires it into two graphs:

        * a training graph fed by a ``DataGenerator`` queue yielding
          ``(image, single_label, single_heatmap)``, with loss, regularization
          loss and an Adam optimizer with optional global-norm gradient clipping;
        * a validation graph fed by placeholders with a batch dimension of 1,
          plus validation losses when ``self.has_validation_groundtruth`` is set.

        In both graphs the image and the single-channel heatmap are concatenated
        along the channel axis (which depends on ``self.data_format``) before
        being passed to the network.
        """
        network_image_size = list(reversed(self.image_size))

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([('image', [1] + network_image_size),
                                                  ('single_label', [self.num_labels] + network_image_size),
                                                  ('single_heatmap', [1] + network_image_size)])
            # Batched tensors are [batch, channel, ...spatial].
            channel_axis = 1
        else:
            data_generator_entries = OrderedDict([('image', network_image_size + [1]),
                                                  ('single_label', network_image_size + [self.num_labels]),
                                                  ('single_heatmap', network_image_size + [1])])
            # Batched tensors are [batch, ...spatial, channel]; concatenating on
            # axis 1 here would stack along a spatial dimension instead.
            channel_axis = -1

        # NOTE(review): no entry is named 'labels' (entries are 'image',
        # 'single_label', 'single_heatmap'), so presumably this dtype is never
        # applied and 'single_label' uses the default type. Confirm before
        # renaming it, since the loss may depend on float labels.
        data_generator_types = {'image':  tf.float32,
                                'labels': tf.uint8}

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGenerator(coord=self.coord,
                                         dataset=self.dataset_train,
                                         data_names_and_shapes=data_generator_entries,
                                         data_types=data_generator_types,
                                         batch_size=self.batch_size)
        data, mask, single_heatmap = self.train_queue.dequeue()
        data_heatmap_concat = tf.concat([data, single_heatmap], axis=channel_axis)
        prediction = training_net(data_heatmap_concat,
                                  num_labels=self.num_labels,
                                  is_training=True,
                                  actual_network=self.unet,
                                  padding=self.padding,
                                  **self.network_parameters)
        # losses
        self.loss_net = self.loss_function(labels=mask, logits=prediction, data_format=self.data_format)
        self.loss_reg = get_reg_loss(self.reg_constant)
        self.loss = self.loss_net + self.loss_reg

        # solver
        global_step = tf.Variable(self.current_iter, trainable=False)
        learning_rate = tf.train.piecewise_constant(global_step, self.learning_rate_boundaries, self.learning_rates)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        unclipped_gradients, variables = zip(*optimizer.compute_gradients(self.loss))
        # The norm is always reported, even when clipping is disabled.
        norm = tf.global_norm(unclipped_gradients)
        if self.clip_gradient_global_norm > 0:
            gradients, _ = tf.clip_by_global_norm(unclipped_gradients, self.clip_gradient_global_norm)
        else:
            gradients = unclipped_gradients
        self.optimizer = optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)
        self.train_losses = OrderedDict([('loss', self.loss_net), ('loss_reg', self.loss_reg), ('gradient_norm', norm)])

        # build val graph
        self.data_val, self.mask_val, self.single_heatmap_val = create_placeholders_tuple(data_generator_entries,
                                                                                          data_types=data_generator_types,
                                                                                          shape_prefix=[1])
        self.data_heatmap_concat_val = tf.concat([self.data_val, self.single_heatmap_val], axis=channel_axis)
        self.prediction_val = training_net(self.data_heatmap_concat_val,
                                           num_labels=self.num_labels,
                                           is_training=False,
                                           actual_network=self.unet,
                                           padding=self.padding,
                                           **self.network_parameters)
        # Despite the attribute name, this is a sigmoid, not a softmax; the name
        # is kept for callers that read it.
        self.prediction_softmax_val = tf.nn.sigmoid(self.prediction_val)

        if self.has_validation_groundtruth:
            self.loss_val = self.loss_function(labels=self.mask_val, logits=self.prediction_val, data_format=self.data_format)
            # Placeholder zero keeps the val dict keys aligned with train_losses.
            self.val_losses = OrderedDict([('loss', self.loss_val), ('loss_reg', self.loss_reg), ('gradient_norm', tf.constant(0, tf.float32))])
    def init_networks(self):
        """
        Initialize networks and placeholders.

        Builds a shared-weight network (``tf.make_template`` scope ``'net'``) and
        wires it into two graphs:

        * a training graph fed by a ``DataGenerator`` queue yielding
          ``(image, spine_heatmap)``, with a heatmap loss, regularization loss and
          an Adam optimizer;
        * a validation graph fed by placeholders with a batch dimension of 1,
          plus validation losses when ``self.has_validation_groundtruth`` is set.
        """
        network_image_size = list(reversed(self.image_size))

        if self.data_format == 'channels_first':
            data_generator_entries = OrderedDict([('image', [1] + network_image_size),
                                                  ('spine_heatmap', [1] + network_image_size)])
        else:
            # Both entries are single-channel with the channel axis last; the
            # heatmap must use the same layout as the image and the prediction.
            data_generator_entries = OrderedDict([('image', network_image_size + [1]),
                                                  ('spine_heatmap', network_image_size + [1])])

        data_generator_types = {'image': tf.float32,
                                'spine_heatmap': tf.float32}

        # create model with shared weights between train and val
        training_net = tf.make_template('net', self.network)

        # build train graph
        self.train_queue = DataGenerator(coord=self.coord,
                                         dataset=self.dataset_train,
                                         data_names_and_shapes=data_generator_entries,
                                         data_types=data_generator_types,
                                         batch_size=self.batch_size)
        data, target_spine_heatmap = self.train_queue.dequeue()

        prediction = training_net(data,
                                  num_labels=self.num_labels,
                                  is_training=True,
                                  actual_network=self.unet,
                                  padding=self.padding,
                                  **self.network_parameters)
        self.loss_net = self.loss_function(target=target_spine_heatmap, pred=prediction)
        self.loss_reg = get_reg_loss(self.reg_constant)
        # Cast so the sum does not fail on a dtype mismatch with loss_net.
        self.loss = self.loss_net + tf.cast(self.loss_reg, tf.float32)

        # solver
        global_step = tf.Variable(self.current_iter, trainable=False)
        learning_rate = tf.train.piecewise_constant(global_step, self.learning_rate_boundaries, self.learning_rates)
        self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss, global_step=global_step)
        self.train_losses = OrderedDict([('loss', self.loss_net), ('loss_reg', self.loss_reg)])

        # build val graph
        self.data_val, self.target_spine_heatmap_val = create_placeholders_tuple(data_generator_entries,
                                                                                 data_types=data_generator_types,
                                                                                 shape_prefix=[1])
        self.prediction_val = training_net(self.data_val,
                                           num_labels=self.num_labels,
                                           is_training=False,
                                           actual_network=self.unet,
                                           padding=self.padding,
                                           **self.network_parameters)

        if self.has_validation_groundtruth:
            self.loss_val = self.loss_function(target=self.target_spine_heatmap_val, pred=self.prediction_val)
            self.val_losses = OrderedDict([('loss', self.loss_val), ('loss_reg', self.loss_reg)])
Example #4
0
    def init_networks(self):
        """
        Initialize networks and placeholders.

        Graph pieces are created in this order:

        1. A per-landmark ``sigmas`` variable initialized to ``self.heatmap_sigma``.
           Its gradient is stopped unless ``self.learnable_sigma`` is set.
        2. The training graph:
           - a ``DataGenerator`` queue yielding image, landmarks and landmark mask;
           - target heatmaps rendered with ``generate_heatmap_target``;
           - the shared network;
           - heatmap, sigma and regularization losses;
           - a Nesterov momentum optimizer with optional global-norm clipping.
        3. The validation graph on placeholders with a batch dimension of 1.
           Validation losses are added only when
           ``self.has_validation_groundtruth`` is set.
        """
        channels_first = self.data_format == 'channels_first'

        def with_channel(spatial):
            # Fresh list per call: one channel in front of or behind the spatial dims.
            return [1] + spatial if channels_first else spatial + [1]

        spatial_shape = list(reversed(self.image_size))
        entries = OrderedDict([
            ('image', with_channel(spatial_shape)),
            ('landmarks', [self.num_landmarks, 4]),
            ('landmark_mask', with_channel(spatial_shape)),
        ])
        types = {'image': tf.float32}

        # Per-landmark heatmap sigmas; frozen unless learnable_sigma is set.
        sigma_init = tf.constant_initializer(self.heatmap_sigma)
        sigmas = tf.get_variable('sigmas', [self.num_landmarks], initializer=sigma_init)
        sigmas = sigmas if self.learnable_sigma else tf.stop_gradient(sigmas)
        mean_sigmas = tf.reduce_mean(sigmas)

        # Template so train and val graphs share the same network weights.
        net = tf.make_template('net', self.network)

        # ---- training graph ----
        self.train_queue = DataGenerator(coord=self.coord,
                                         dataset=self.dataset_train,
                                         data_names_and_shapes=entries,
                                         data_types=types,
                                         batch_size=self.batch_size)
        images, landmarks, mask = self.train_queue.dequeue()
        heatmaps = generate_heatmap_target(list(reversed(self.heatmap_size)),
                                           landmarks,
                                           sigmas,
                                           scale=self.sigma_scale,
                                           normalize=True,
                                           data_format=self.data_format)
        prediction, _, _ = net(images,
                               num_labels=self.num_landmarks,
                               is_training=True,
                               actual_network=self.unet,
                               padding=self.padding,
                               **self.network_parameters)

        self.loss_net = self.loss_function(target=heatmaps, pred=prediction, mask=mask)
        # Only batch item 0 and landmark component 0 are passed here (presumably a
        # per-landmark validity flag -- TODO confirm).
        self.loss_sigmas = self.loss_function_sigmas(sigmas, landmarks[0, :, 0])
        self.loss_reg = get_reg_loss(self.reg_constant)
        self.loss = self.loss_net + self.loss_reg + self.loss_sigmas

        # ---- solver ----
        global_step = tf.Variable(self.current_iter, trainable=False)
        learning_rate = tf.train.piecewise_constant(global_step,
                                                    self.learning_rate_boundaries,
                                                    self.learning_rates)
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=0.99,
                                               use_nesterov=True)
        grads_and_vars = optimizer.compute_gradients(self.loss)
        raw_gradients, variables = zip(*grads_and_vars)
        # The norm is reported before clipping, whether or not clipping is enabled.
        gradient_norm = tf.global_norm(raw_gradients)
        if self.clip_gradient_global_norm > 0:
            gradients = tf.clip_by_global_norm(raw_gradients, self.clip_gradient_global_norm)[0]
        else:
            gradients = raw_gradients
        self.optimizer = optimizer.apply_gradients(zip(gradients, variables),
                                                   global_step=global_step)
        self.train_losses = OrderedDict([
            ('loss', self.loss_net),
            ('loss_reg', self.loss_reg),
            ('loss_sigmas', self.loss_sigmas),
            ('mean_sigmas', mean_sigmas),
            ('gradient_norm', gradient_norm),
        ])

        # ---- validation graph ----
        placeholders = create_placeholders_tuple(entries, data_types=types, shape_prefix=[1])
        self.data_val, self.target_landmarks_val, self.landmark_mask_val = placeholders
        self.target_heatmaps_val = generate_heatmap_target(list(reversed(self.heatmap_size)),
                                                           self.target_landmarks_val,
                                                           sigmas,
                                                           scale=self.sigma_scale,
                                                           normalize=True,
                                                           data_format=self.data_format)
        val_outputs = net(self.data_val,
                          num_labels=self.num_landmarks,
                          is_training=False,
                          actual_network=self.unet,
                          padding=self.padding,
                          **self.network_parameters)
        self.prediction_val, self.local_prediction_val, self.spatial_prediction_val = val_outputs

        if not self.has_validation_groundtruth:
            return

        # NOTE(review): unlike the training loss, no mask is passed here, so
        # landmark_mask_val is not used by this method -- confirm that is intended.
        self.loss_val = self.loss_function(target=self.target_heatmaps_val,
                                           pred=self.prediction_val)
        # Zero placeholders keep the val dict keys aligned with train_losses.
        self.val_losses = OrderedDict([
            ('loss', self.loss_val),
            ('loss_reg', self.loss_reg),
            ('loss_sigmas', tf.constant(0, tf.float32)),
            ('mean_sigmas', tf.constant(0, tf.float32)),
            ('gradient_norm', tf.constant(0, tf.float32)),
        ])