Example #1
0
    def setup_training(self, session):
        """
            Build the training objective, the tracked error metric and the summary writers.

            - Every side-layer output and the fuse-layer output go through a sigmoid;
              the resulting maps are collected in `self.predictions` (fuse output last).
            - Total loss := weighted fuse-layer loss, plus the weighted side-layer losses
              when `cfgs['deep_supervision']` is enabled.
            - The fuse-layer map thresholded at 0.5 is compared to the edge maps, and the
              mean pixel mismatch is tracked as a pseudo performance metric.
        """
        cfgs = self.cfgs
        self.predictions = []
        self.loss = 0

        deep_supervision = cfgs['deep_supervision']
        self.io.print_warning(
            'Deep supervision application set to {}'.format(deep_supervision))

        for side_idx, side_logits in enumerate(self.side_outputs):
            side_prob = tf.nn.sigmoid(side_logits,
                                      name='output_{}'.format(side_idx))
            side_cost = sigmoid_cross_entropy_balanced(
                side_logits, self.edgemaps,
                name='cross_entropy{}'.format(side_idx))
            self.predictions.append(side_prob)
            # Side losses only contribute when deep supervision is switched on.
            if deep_supervision:
                self.loss = self.loss + cfgs['loss_weights'] * side_cost

        fuse_prob = tf.nn.sigmoid(self.fuse, name='fuse')
        fuse_cost = sigmoid_cross_entropy_balanced(
            self.fuse, self.edgemaps, name='cross_entropy_fuse')
        self.predictions.append(fuse_prob)
        self.loss = self.loss + cfgs['loss_weights'] * fuse_cost

        # A pixel is predicted as an edge once its fused probability exceeds 0.5.
        is_edge = tf.greater(fuse_prob, 0.5)
        pred = tf.cast(is_edge, tf.int32, name='predictions')
        truth = tf.cast(self.edgemaps, tf.int32)
        mismatch = tf.not_equal(pred, truth)
        self.error = tf.reduce_mean(tf.cast(mismatch, tf.float32),
                                    name='pixel_error')

        for tag, value in (('loss', self.loss), ('error', self.error)):
            tf.summary.scalar(tag, value)

        self.merged_summary = tf.summary.merge_all()

        log_root = cfgs['save_dir']
        self.train_writer = tf.summary.FileWriter(log_root + '/train',
                                                  session.graph)
        self.val_writer = tf.summary.FileWriter(log_root + '/val')
    def _loss_def(self, loss_type, vgg_fmaps=None, vgg_weights=None):
        """
        Wrapper that builds the training loss selected by `loss_type`.

        Always sets `self.loss_type`, `self.loss` and `self.loss_summaries` (a list of
        scalar summary ops). 'bce' also sets `self.bce_raw_summary`. 'bce-topo' also sets
        `self.loss_bce`, `self.loss_topo`, `self.bce_raw_summary` and `self.topo_raw_summary`.

        :param loss_type: a string defining the loss type, below are the eligible loss types per model
                      'UNET': ['bce', 'bce-topo']
                      'iUNET': ['i-bce', 'i-bce-equal', 'i-bce-topo', 'i-bce-topo-equal']
                      'SHN': ['s-bce', 's-bce-topo']
        :param vgg_fmaps: list of the vgg feature maps used for the perceptual (topological) loss
        :param vgg_weights: list of weights signifying the importance of each vgg feature map in the loss
        :return: None (results are stored on `self`)
        :raises AssertionError: if `loss_type` is not eligible for the model `self.name`
        """
        eligible_losses = self.eligible_loss_types_per_model[self.name]
        # Raised explicitly rather than via `assert` so the check is not stripped
        # under `python -O`; AssertionError is kept so callers see the same type.
        if loss_type not in eligible_losses:
            raise AssertionError(
                'loss_type: [{}] not eligible for model: [{}] eligible losses for it are:[{}]'
                .format(loss_type, self.name, eligible_losses))

        # iteration_weighing: if 'equal', the loss term of each intermediate output is
        # weighed equally; otherwise gradually increasing weights are used.
        # Stacked (SHN) losses ('s-bce' and 's-bce-topo') always weigh each module's
        # term equally, regardless of whether 'equal' appears in the loss name.
        if loss_type in ('s-bce', 's-bce-topo') or 'equal' in loss_type:
            iteration_weighing = 'equal'
        else:
            iteration_weighing = 'increasing'

        print('[1]: Loss definition')
        print(
            'loss type [{}], iteration_weighing [{}], vgg_fmaps [{}], vgg_weights [{}]'
            .format(loss_type, iteration_weighing, vgg_fmaps, vgg_weights))
        self.loss_type = loss_type
        with tf.name_scope('loss'):
            if loss_type == 'bce':
                # balanced binary cross entropy (neither for stacked (SHN) nor iterative (iUNET) models)
                self.loss = sigmoid_cross_entropy_balanced(self.logits,
                                                           self.y,
                                                           name='bce')
                self.bce_raw_summary = tf.summary.scalar('bce', self.loss)
                # Wrapped in a list: a bare summary tensor causes an error when the
                # summaries are merged later.
                self.loss_summaries = [self.bce_raw_summary]

            elif loss_type == 'bce-topo':
                # bce + topological (perceptual) loss (neither for stacked nor iterative models)
                self.loss_bce = sigmoid_cross_entropy_balanced(self.logits,
                                                               self.y,
                                                               name='bce')
                self.loss_topo = perceptual_loss(tf.nn.sigmoid(self.logits),
                                                 self.y, vgg_fmaps,
                                                 vgg_weights)
                self.bce_raw_summary = tf.summary.scalar('bce', self.loss_bce)
                self.topo_raw_summary = tf.summary.scalar(
                    'topo', self.loss_topo)
                self.loss = self.loss_bce + self.loss_topo
                self.loss_summaries = [
                    self.bce_raw_summary, self.topo_raw_summary
                ]

            elif loss_type in ('i-bce', 'i-bce-equal'):
                # iterative loss with bce terms (only for iUNET)
                self.loss, self.loss_summaries = iterative_loss(
                    self.sigmoided_logits,
                    self.logits,
                    self.y,
                    n_iterations=self.n_iterations,
                    iteration_weighing=iteration_weighing)

            elif loss_type in ('i-bce-topo', 'i-bce-topo-equal'):
                # iterative loss with bce + topological terms (only for iUNET)
                self.loss, self.loss_summaries = iterative_loss(
                    self.sigmoided_logits,
                    self.logits,
                    self.y,
                    n_iterations=self.n_iterations,
                    iteration_weighing=iteration_weighing,
                    use_vgg_loss=True,
                    vgg_fmaps=vgg_fmaps,
                    vgg_weights=vgg_weights)

            elif loss_type == 's-bce':
                # iterative loss with bce (only for SHN); one term per stacked module,
                # each weighed equally (iteration_weighing resolved to 'equal' above)
                self.loss, self.loss_summaries = iterative_loss(
                    self.sigmoided_logits,
                    self.logits,
                    self.y,
                    n_iterations=self.n_modules,
                    iteration_weighing=iteration_weighing)

            elif loss_type == 's-bce-topo':
                # iterative loss with bce + topological terms (only for SHN); one term per
                # stacked module, each weighed equally (iteration_weighing resolved to 'equal' above)
                self.loss, self.loss_summaries = iterative_loss(
                    self.sigmoided_logits,
                    self.logits,
                    self.y,
                    n_iterations=self.n_modules,
                    iteration_weighing=iteration_weighing,
                    use_vgg_loss=True,
                    vgg_fmaps=vgg_fmaps,
                    vgg_weights=vgg_weights)