Exemplo n.º 1
0
    def build_cls_loss(self, batch_size, pos_mask, neg_mask, pos_mask_flatten, neg_mask_flatten, n_pos, do_summary, pixel_cls_loss_weight_lambda):
        """Build the pixel-classification losses and register them as graph losses.

        Two terms are created under the ``pixel_cls_loss`` name scope:

        * the mean sparse softmax cross-entropy between
          ``self.pixel_cls_logits_flatten`` and ``pos_mask_flatten`` (cast to
          int32 and used as the class labels);
        * a binary dice loss from ``loss_with_binary_dice`` computed on
          ``self.pixel_cls_scores`` against ``pos_mask`` over axes ``[1, 2]``.

        The cross-entropy term is added to ``tf.GraphKeys.LOSSES`` scaled by
        ``pixel_cls_loss_weight_lambda``; the dice term is added scaled by
        ``pixel_cls_loss_weight_lambda * dice_coff``. ``dice_coff`` is not
        defined in this method — presumably a module-level constant; confirm.

        ``batch_size``, ``neg_mask``, ``neg_mask_flatten``, ``n_pos`` and
        ``do_summary`` are accepted for interface compatibility with the
        hard-negative-mining variant but are not used here.

        Returns:
            tuple: ``(pixel_cls_loss, pixel_cls_dice, pixel_cls_dice_loss)`` —
            the unweighted mean cross-entropy, and the two values returned by
            ``loss_with_binary_dice`` (its second and first outputs
            respectively), all unweighted.
        """
        from losses import loss_with_binary_dice
        with tf.name_scope('pixel_cls_loss'):
            # Plain (non-OHEM) cross-entropy averaged over every flattened pixel.
            pixel_cls_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=self.pixel_cls_logits_flatten,
                    labels=tf.cast(pos_mask_flatten, dtype=tf.int32)))
            pixel_cls_dice_loss, pixel_cls_dice = loss_with_binary_dice(
                self.pixel_cls_scores, pos_mask, axis=[1, 2])
            tf.add_to_collection(tf.GraphKeys.LOSSES,
                                 pixel_cls_loss * pixel_cls_loss_weight_lambda)
            tf.add_to_collection(
                tf.GraphKeys.LOSSES,
                pixel_cls_dice_loss * pixel_cls_loss_weight_lambda * dice_coff)
        return pixel_cls_loss, pixel_cls_dice, pixel_cls_dice_loss
Exemplo n.º 2
0
 def build_cls_loss(self, batch_size, pos_mask, neg_mask, pos_mask_flatten,
                    neg_mask_flatten, n_pos, do_summary,
                    pixel_cls_loss_weight_lambda):
     """Create the weighted pixel-classification losses.

     Under the ``pixel_cls_loss`` name scope this builds:

     * the mean sparse softmax cross-entropy of
       ``self.pixel_cls_logits_flatten`` against ``pos_mask_flatten``
       (cast to int32 as labels);
     * the binary dice loss of ``self.pixel_cls_scores`` against ``pos_mask``
       over axes ``[1, 2]``, via ``loss_with_binary_dice``.

     Both are pushed onto ``tf.GraphKeys.LOSSES``: the cross-entropy times
     ``pixel_cls_loss_weight_lambda`` and the dice loss times
     ``pixel_cls_loss_weight_lambda * dice_coff``. ``dice_coff`` comes from
     outside this method (presumably module scope — confirm).

     ``batch_size``, ``neg_mask``, ``neg_mask_flatten``, ``n_pos`` and
     ``do_summary`` are unused here.

     Returns:
         tuple: ``(pixel_cls_loss, pixel_cls_dice, pixel_cls_dice_loss)``,
         all unweighted.
     """
     from losses import loss_with_binary_dice
     with tf.name_scope('pixel_cls_loss'):
         int_labels = tf.cast(pos_mask_flatten, dtype=tf.int32)
         per_pixel_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
             logits=self.pixel_cls_logits_flatten, labels=int_labels)
         xent_loss = tf.reduce_mean(per_pixel_xent)

         dice_loss, dice_value = loss_with_binary_dice(
             self.pixel_cls_scores, pos_mask, axis=[1, 2])

         weighted_xent = xent_loss * pixel_cls_loss_weight_lambda
         tf.add_to_collection(tf.GraphKeys.LOSSES, weighted_xent)
         weighted_dice = dice_loss * pixel_cls_loss_weight_lambda * dice_coff
         tf.add_to_collection(tf.GraphKeys.LOSSES, weighted_dice)
     return xent_loss, dice_value, dice_loss
Exemplo n.º 3
0
    def build_cls_loss(self, batch_size, pos_mask, neg_mask, pos_mask_flatten,
                       neg_mask_flatten, n_pos, do_summary,
                       pixel_cls_loss_weight_lambda):
        """Build OHEM-weighted pixel-classification losses and register them.

        Under the ``pixel_cls_loss`` name scope:

        * Per-pixel sparse softmax cross-entropy is computed between
          ``self.pixel_cls_logits_flatten`` and ``pos_mask_flatten`` (cast to
          int32 as labels).
        * Hard negatives are mined with ``OHNM_batch`` from channel 0 of
          ``self.pixel_cls_scores_flatten``. If mining selects no pixel, every
          pixel of ``neg_mask_flatten`` is used instead.
        * The cross-entropy loss is the sum of per-pixel losses over positive
          and selected-negative pixels, divided by ``n_neg + n_pos``. The
          divisor is clamped to at least 1, so a batch with no positive and no
          negative pixels yields 0 instead of NaN.
        * A binary dice loss is computed on ``self.pixel_cls_scores`` against
          ``pos_mask`` over axes ``[1, 2]`` via ``loss_with_binary_dice``.

        Both losses are added to ``tf.GraphKeys.LOSSES``. The cross-entropy is
        scaled by ``pixel_cls_loss_weight_lambda``; the dice loss is scaled by
        ``pixel_cls_loss_weight_lambda * dice_coff``.

        The method also writes an image summary ``selected_neg_mask_cls``,
        showing the mined mask reshaped to ``tf.shape(pos_mask)`` with a
        trailing channel axis. This summary is emitted unconditionally;
        ``do_summary`` is not consulted.

        NOTE(review): ``OHNM_batch`` and ``dice_coff`` are not defined in this
        method. Presumably both come from module scope — confirm.

        Args:
            n_pos: presumably the number of positive pixels as a float32
                scalar. It is added to a float32 tensor, so it must be
                float-compatible — confirm against the caller.

        Returns:
            tuple: ``(pixel_cls_loss, pixel_cls_dice, pixel_cls_dice_loss)``,
            all unweighted.
        """
        from losses import loss_with_binary_dice
        with tf.name_scope('pixel_cls_loss'):
            per_pixel_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.pixel_cls_logits_flatten,
                labels=tf.cast(pos_mask_flatten, dtype=tf.int32))

            # Online hard negative mining on channel 0 of the flattened scores.
            pixel_neg_scores = self.pixel_cls_scores_flatten[:, :, 0]
            mined_neg_mask = OHNM_batch(batch_size, pixel_neg_scores,
                                        pos_mask_flatten, neg_mask_flatten)

            tf.summary.image(
                'selected_neg_mask_cls',
                tf.expand_dims(
                    tf.cast(tf.reshape(mined_neg_mask, tf.shape(pos_mask)),
                            tf.float32),
                    axis=3),
                max_outputs=1)

            # If mining selected nothing, fall back to all negative pixels.
            # Both branches are cast to bool so tf.cond sees matching dtypes.
            n_mined = tf.cast(tf.reduce_sum(mined_neg_mask), tf.float32)
            selected_neg_pixel_mask = tf.cond(
                tf.equal(n_mined, 0),
                lambda: tf.cast(neg_mask_flatten, tf.bool),
                lambda: tf.cast(mined_neg_mask, tf.bool))
            n_neg = tf.reduce_sum(tf.cast(selected_neg_pixel_mask, tf.float32))

            # Positives always contribute; negatives only where selected.
            # NOTE: an instance-balanced variant was previously sketched that
            # used ``self.pixel_cls_weight`` in place of ``pos_mask_flatten``.
            cur_pixel_cls_weights = (
                tf.cast(pos_mask_flatten, tf.float32) +
                tf.cast(selected_neg_pixel_mask, tf.float32))

            # Pixel counts are integral, so clamping only affects the empty
            # case. Without it, an all-zero weight mask gives 0 / 0 = NaN.
            normalizer = tf.maximum(n_neg + n_pos, 1.0)
            pixel_cls_loss = tf.reduce_sum(
                per_pixel_loss * cur_pixel_cls_weights) / normalizer

            pixel_cls_dice_loss, pixel_cls_dice = loss_with_binary_dice(
                self.pixel_cls_scores, pos_mask, axis=[1, 2])
            tf.add_to_collection(tf.GraphKeys.LOSSES,
                                 pixel_cls_loss * pixel_cls_loss_weight_lambda)
            tf.add_to_collection(
                tf.GraphKeys.LOSSES,
                pixel_cls_dice_loss * pixel_cls_loss_weight_lambda * dice_coff)
        return pixel_cls_loss, pixel_cls_dice, pixel_cls_dice_loss