def setup_training(self, session):
    """
    Apply sigmoid non-linearity to side layer outputs + fuse layer outputs
    Compute total loss := side_layer_loss + fuse_layer_loss
    Compute predicted edge maps from fuse layer as pseudo performance metric to track
    """

    self.predictions = []
    self.loss = 0

    self.io.print_warning('Deep supervision application set to {}'.format(self.cfgs['deep_supervision']))

    # one sigmoid output and one balanced-BCE term per side output
    for idx, b in enumerate(self.side_outputs):
        output = tf.nn.sigmoid(b, name='output_{}'.format(idx))
        cost = sigmoid_cross_entropy_balanced(b, self.edgemaps, name='cross_entropy{}'.format(idx))

        self.predictions.append(output)

        # side-output losses only contribute when deep supervision is enabled
        if self.cfgs['deep_supervision']:
            self.loss += (self.cfgs['loss_weights'] * cost)

    # fuse layer output always contributes to the total loss
    fuse_output = tf.nn.sigmoid(self.fuse, name='fuse')
    fuse_cost = sigmoid_cross_entropy_balanced(self.fuse, self.edgemaps, name='cross_entropy_fuse')

    self.predictions.append(fuse_output)
    self.loss += (self.cfgs['loss_weights'] * fuse_cost)

    # pixel error of the thresholded fuse output, tracked as the pseudo performance metric
    pred = tf.cast(tf.greater(fuse_output, 0.5), tf.int32, name='predictions')
    error = tf.cast(tf.not_equal(pred, tf.cast(self.edgemaps, tf.int32)), tf.float32)
    self.error = tf.reduce_mean(error, name='pixel_error')

    tf.summary.scalar('loss', self.loss)
    tf.summary.scalar('error', self.error)

    self.merged_summary = tf.summary.merge_all()

    self.train_writer = tf.summary.FileWriter(self.cfgs['save_dir'] + '/train', session.graph)
    self.val_writer = tf.summary.FileWriter(self.cfgs['save_dir'] + '/val')
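
# NOTE (illustrative sketch, not necessarily this repo's implementation): the helper
# sigmoid_cross_entropy_balanced() used above is assumed to follow the standard HED-style
# class-balanced cross entropy, where the rare positive (edge) pixels are up-weighted by the
# negative/positive pixel ratio. A minimal TF1-style reference version could look like this:
import tensorflow as tf  # same TF1.x API as the rest of this module

def sigmoid_cross_entropy_balanced_sketch(logits, label, name='cross_entropy_loss'):
    """Class-balanced sigmoid cross entropy (HED-style); hypothetical reference only."""
    y = tf.cast(label, tf.float32)
    count_neg = tf.reduce_sum(1. - y)           # number of non-edge pixels
    count_pos = tf.reduce_sum(y)                # number of edge pixels
    beta = count_neg / (count_neg + count_pos)  # fraction of negatives
    pos_weight = beta / (1. - beta)             # up-weight the (rare) positives
    cost = tf.nn.weighted_cross_entropy_with_logits(targets=y, logits=logits, pos_weight=pos_weight)
    # scale by (1 - beta) and guard against images with no positive pixels at all
    return tf.where(tf.equal(count_pos, 0.0), 0.0, tf.reduce_mean(cost * (1. - beta)), name=name)
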
def _loss_def(self, loss_type, vgg_fmaps=None, vgg_weights=None):
    """
    Wrapper function that sets the loss function to be used during training
    :param loss_type: a string defining the loss type, below are the eligible loss types per model
                      'UNET':  ['bce', 'bce-topo']
                      'iUNET': ['i-bce', 'i-bce-equal', 'i-bce-topo', 'i-bce-topo-equal']
                      'SHN':   ['s-bce', 's-bce-topo']
    :param vgg_fmaps: list of the vgg feature maps used for the perceptual loss
    :param vgg_weights: list of weights signifying the importance of each vgg feature map in the loss
    :return:
    """
    assert (loss_type in self.eligible_loss_types_per_model[self.name]), \
        'loss_type: [{}] not eligible for model: [{}], eligible losses for it are: [{}]'.format(
            loss_type, self.name, self.eligible_loss_types_per_model[self.name])

    # iteration_weighing: if 'equal' then the iUNET loss term for each intermediate output is weighed
    # equally, otherwise gradually increasing weights are used
    iteration_weighing = 'equal' if 'equal' in loss_type else 'increasing'

    print('[1]: Loss definition')
    print('loss type [{}], iteration_weighing [{}], vgg_fmaps [{}], vgg_weights [{}]'.format(
        loss_type, iteration_weighing, vgg_fmaps, vgg_weights))

    self.loss_type = loss_type

    with tf.name_scope('loss'):
        if loss_type == 'bce':
            # balanced binary cross entropy (not for stacked (SHN) nor for iterative (iUNET) models)
            self.loss = sigmoid_cross_entropy_balanced(self.logits, self.y, name='bce')
            self.bce_raw_summary = tf.summary.scalar('bce', self.loss)
            # the summary is placed in a list because a lone tensor would cause an error in merge
            self.loss_summaries = [self.bce_raw_summary]

        elif loss_type == 'bce-topo':
            # bce + topological (perceptual) loss (not for stacked nor for iterative models)
            self.loss_bce = sigmoid_cross_entropy_balanced(self.logits, self.y, name='bce')
            self.loss_topo = perceptual_loss(tf.nn.sigmoid(self.logits), self.y, vgg_fmaps, vgg_weights)
            self.bce_raw_summary = tf.summary.scalar('bce', self.loss_bce)
            self.topo_raw_summary = tf.summary.scalar('topo', self.loss_topo)
            self.loss = self.loss_bce + self.loss_topo
            self.loss_summaries = [self.bce_raw_summary, self.topo_raw_summary]

        elif loss_type in ('i-bce', 'i-bce-equal'):
            # iterative loss with bce terms (only for iUNET)
            self.loss, self.loss_summaries = iterative_loss(
                self.sigmoided_logits, self.logits, self.y,
                n_iterations=self.n_iterations,
                iteration_weighing=iteration_weighing)

        elif loss_type in ('i-bce-topo', 'i-bce-topo-equal'):
            # iterative loss with bce + topological terms (only for iUNET)
            self.loss, self.loss_summaries = iterative_loss(
                self.sigmoided_logits, self.logits, self.y,
                n_iterations=self.n_iterations,
                iteration_weighing=iteration_weighing,
                use_vgg_loss=True, vgg_fmaps=vgg_fmaps, vgg_weights=vgg_weights)

        elif loss_type == 's-bce':
            # iterative loss with bce terms (only for SHN)
            # each intermediate output's term is weighed equally, i.e. iteration_weighing='equal'
            self.loss, self.loss_summaries = iterative_loss(
                self.sigmoided_logits, self.logits, self.y,
                n_iterations=self.n_modules,
                iteration_weighing='equal')

        elif loss_type == 's-bce-topo':
            # iterative loss with bce + topological terms (only for SHN)
            # each intermediate output's term is weighed equally, i.e. iteration_weighing='equal'
            self.loss, self.loss_summaries = iterative_loss(
                self.sigmoided_logits, self.logits, self.y,
                n_iterations=self.n_modules,
                iteration_weighing='equal',  # intermediate outputs weighed equally, as for 's-bce'
                use_vgg_loss=True, vgg_fmaps=vgg_fmaps, vgg_weights=vgg_weights)
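
# NOTE (illustrative sketch): iterative_loss() above is assumed to sum one balanced-BCE term per
# intermediate output (one per iUNET iteration, or one per SHN module), optionally adding a
# perceptual/topological term, with per-iteration weights that are either equal or gradually
# increasing. The hypothetical helper below only illustrates the two weighing schemes; it is not
# the repo's implementation:
def iteration_weights_sketch(n_iterations, iteration_weighing='increasing'):
    """Return one weight per intermediate output, normalised to sum to 1 (hypothetical)."""
    if iteration_weighing == 'equal':
        return [1.0 / n_iterations] * n_iterations
    # 'increasing': later iterations (closer to the final output) contribute more to the loss
    raw = [float(i + 1) for i in range(n_iterations)]
    total = sum(raw)
    return [w / total for w in raw]

# Example: with 3 iUNET iterations, 'equal' gives [0.33, 0.33, 0.33] while 'increasing' gives
# roughly [0.17, 0.33, 0.50], so the final output dominates the total loss.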