def create_bce_loss(output, label, num_classes, ignore_label):
    """Build a softmax cross-entropy loss combined with a negative-log Dice term.

    Despite the function name, the cross-entropy part is computed with
    ``tf.nn.softmax_cross_entropy_with_logits_v2`` on one-hot targets.

    Args:
        output: Prediction tensor. It is flattened to ``[-1, num_classes]``,
            and entries 1 and 2 of its static shape are handed to
            ``prepare_label`` (presumably height and width -- confirm
            against the caller).
        label: Ground-truth tensor, forwarded to ``prepare_label`` with
            ``one_hot=False`` and then flattened to one dimension.
        num_classes: Number of classes; used for the reshape, the mask
            helper and the one-hot depth.
        ignore_label: Forwarded to ``get_mask`` (presumably the label value
            to exclude from the loss -- confirm against ``get_mask``).

    Returns:
        Tensor equal to the mean cross-entropy over the selected entries
        minus ``log(dice_coef_theoretical(pred, gt))``.
    """
    # One row of num_classes scores per prediction position.
    flat_logits = tf.reshape(output, [-1, num_classes])

    # Labels are prepared against axes 1-2 of the prediction's static shape
    # and kept as class indices (one_hot=False) before being flattened.
    target_size = tf.stack(output.get_shape()[1:3])
    flat_labels = tf.reshape(
        prepare_label(label, target_size, num_classes=num_classes, one_hot=False),
        [-1],
    )

    # Restrict both labels and predictions to the indices chosen by get_mask.
    keep = get_mask(flat_labels, num_classes, ignore_label)
    kept_labels = tf.cast(tf.gather(flat_labels, keep), tf.int32)
    kept_one_hot = tf.one_hot(kept_labels, num_classes)
    kept_logits = tf.gather(flat_logits, keep)

    per_entry_ce = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=kept_logits, labels=kept_one_hot
    )
    cross_entropy = tf.reduce_mean(per_entry_ce)

    # NOTE(review): the log below is only finite if dice_coef_theoretical
    # returns a strictly positive value -- confirm that helper smooths or
    # otherwise bounds its result away from zero.
    dice = dice_coef_theoretical(kept_logits, kept_labels)
    return cross_entropy - tf.math.log(dice)
def create_loss(output, label, num_classes, ignore_label):
    """Build the mean sparse softmax cross-entropy loss over selected entries.

    Args:
        output: Prediction tensor. It is flattened to ``[-1, num_classes]``,
            and entries 1 and 2 of its static shape are handed to
            ``prepare_label`` (presumably height and width -- confirm
            against the caller).
        label: Ground-truth tensor, forwarded to ``prepare_label`` with
            ``one_hot=False`` and then flattened to one dimension.
        num_classes: Number of classes; used for the reshape and the mask
            helper.
        ignore_label: Forwarded to ``get_mask`` (presumably the label value
            to exclude from the loss -- confirm against ``get_mask``).

    Returns:
        Scalar tensor: the mean of the per-entry sparse softmax
        cross-entropy over the indices returned by ``get_mask``.
    """
    # One row of num_classes scores per prediction position.
    flat_logits = tf.reshape(output, [-1, num_classes])

    # Labels are prepared against axes 1-2 of the prediction's static shape
    # and kept as class indices (one_hot=False) before being flattened.
    target_size = tf.stack(output.get_shape()[1:3])
    flat_labels = tf.reshape(
        prepare_label(label, target_size, num_classes=num_classes, one_hot=False),
        [-1],
    )

    # Restrict both labels and predictions to the indices chosen by get_mask.
    keep = get_mask(flat_labels, num_classes, ignore_label)
    kept_labels = tf.cast(tf.gather(flat_labels, keep), tf.int32)
    kept_logits = tf.gather(flat_logits, keep)

    per_entry_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=kept_logits, labels=kept_labels
    )
    return tf.reduce_mean(per_entry_ce)