def dice_coeff_forward(pred, target, axis=None, axis_img=(1, 2), threshold=0.2):
    """
    exact (hard-thresholded) dice coeff, for evaluation rather than training.

    pred is binarized first: values <= threshold become 0, all others become 1.
    A sample whose binarized pred is empty scores 1 if its target is also empty,
    otherwise 0.

    :param pred: [batch, h, w, 1]
    :param target: [batch, h, w, 1], assumed to contain only zeros and ones
    :param axis: None -> mean over everything; otherwise the axes for the final mean
    :param axis_img: image axes summed per sample, default (1, 2)
    :param threshold: binarization threshold (pixel <= threshold -> 0)
    :return: scalar when axis is None, else dice_batch averaged over `axis`
    """
    is_low = tf.less_equal(pred, threshold)
    binary = tf.where(is_low, tf.zeros(shape=pred.dims), tf.ones(shape=pred.dims))

    pred_area = binary.sum(axis=axis_img)
    target_area = target.sum(axis=axis_img)
    overlap = tf.sum(binary * target, axis=axis_img)

    # samples with no predicted foreground take the "empty mask" score instead
    has_pred = tf.any(binary, axis=axis_img)
    empty_score = 1. - tf.any(target, axis=axis_img).to_float()
    per_sample = tf.where(has_pred, 2. * overlap / (pred_area + target_area), empty_score)

    return per_sample.mean() if axis is None else per_sample.mean(axis)
def log_dice_coeff(pred, target, axis=(1, 2), eps=1e-8):
    """
    log of the soft dice coeff: log(2 * intersect / (pred + target))

    :param pred: prediction tensor, presumably in range(0,1) -- TODO confirm
    :param target: target tensor, same shape as pred
    :param axis: axes summed over, default (1, 2)
    :param eps: added inside both logs to avoid log(0)
    :return: log-dice values with `axis` reduced away
    """
    ps = pred.sum(axis=axis, keepdims=False)
    ts = target.sum(axis=axis, keepdims=False)
    intersect = tf.sum(pred * target, axis=axis, keepdims=False)
    # log(2*I / (P + T)) == log(2*I) - log(P + T).
    # The factor 2 belongs inside the log; 2*log(I) would be log(I**2), which
    # is not the dice coefficient. eps placement matches dice_coeff:
    # log((2*I + eps) / (P + T + eps)).
    return tf.log(2. * intersect + eps) - tf.log(ps + ts + eps)
def jaccard_index(pred, target, axis=None, eps=1e-8):
    """
    jaccard index (IoU): intersect / (pred + target - intersect)

    :param pred: prediction tensor, presumably in range(0,1); same ndim as target
    :param target: target tensor, presumably 0 or 1; same ndim as pred
    :param axis: reduction axes. None -> every axis except the first and the
        last (batch and channel, assuming nhwc layout). axis=0 is honored.
    :param eps: added to numerator and denominator to prevent NaN
        (empty pred and empty target give 1)
    :return: jaccard index with `axis` reduced away
    """
    assert pred.ndim == target.ndim
    # `is None` rather than `axis or ...`: axis=0 is a valid axis but falsy,
    # and would otherwise be silently replaced by the default axes.
    if axis is None:
        axis = list(range(1, pred.ndim - 1))
    ps = pred.sum(axis=axis)
    ts = target.sum(axis=axis)
    intersect = tf.sum(pred * target, axis=axis)
    return (intersect + eps) / (ps + ts - intersect + eps)
def sparsemax_loss(logits, sparsemax, labels, axis=-1, name=None):
    """
    sparsemax loss, reduced along `axis`.

    Adapted from tf.contrib.sparsemax.sparsemax_loss (TF r1.1):
    https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py

    With z = logits - mean(logits) along `axis`:
        loss = sum_{j: p_j > 0} p_j * (z_j - p_j / 2) + sum_j q_j * (q_j / 2 - z_j)
    where p = sparsemax and q = labels.

    :param logits: raw scores
    :param sparsemax: sparsemax(logits) along the same axis (caller supplies it)
    :param labels: target distribution, same shape as logits
    :param axis: class axis, default last
    :param name: unused
    :return: per-example loss with `axis` reduced away
    """
    centered = logits - tf.mean(logits, axis=axis, keepdims=True)

    # only entries inside the sparsemax support contribute to the first term
    in_support = tf.cast(sparsemax > 0, sparsemax.dtype)
    support_term = in_support * sparsemax * (centered - 0.5 * sparsemax)

    # -z_k + ||q||^2 / 2 part, written elementwise over the labels
    label_term = labels * (0.5 * labels - centered)

    return tf.sum(support_term + label_term, axis=axis)
def dice_coeff(pred, target, axis=None, eps=1e-8):
    """
    soft dice coeff, ndim == 4, nhwc format
    (2 * intersect + eps) / (pred + target + eps)

    :param pred: assume : range(0,1), 4dim
    :param target: 0 or 1 nhwc format 4dim
    :param axis: reduction axes. None -> every axis except the first and the
        last (h, w for nhwc). axis=0 is honored (not treated as None).
    :param eps: for prevent NaN; added to numerator and denominator, so an
        empty pred with an empty target gives 1
    :return: [batch x channel] with the default axis
    """
    assert pred.ndim == target.ndim
    # `is None` rather than `axis or ...`: axis=0 is a valid axis but falsy,
    # and would otherwise be silently replaced by the default axes.
    if axis is None:
        axis = list(range(1, pred.ndim - 1))
    ps = pred.sum(axis=axis)
    ts = target.sum(axis=axis)
    intersect = tf.sum(pred * target, axis=axis)
    return (2. * intersect + eps) / (ps + ts + eps)
def sparsemax(logits, axis=-1, name=None): """ :param logits: tf.Tensor :param axis: :param name: :return: """ # https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/contrib/sparsemax/python/ops/sparsemax.py logits = tf.shiftdim(logits, -axis - 1) # lshape = logits.shape tshape = tf.shape(logits) dims = tshape[axis] logits = tf.reshape(logits, (-1, dims)) obs = tf.shape(logits)[0] # sort z z = logits - tf.mean(logits, axis=1, keepdims=True) z_sorted = tf.sort(z) # calculate k(z) z_cumsum = tf.cumsum(z_sorted, axis=1) k = tf.range(1, dims + 1).astype(dtype=logits.dtype) z_check = 1 + k * z_sorted > z_cumsum k_z = tf.sum(z_check.astype(tf.int32), axis=1) # calculate tau(z) indices = tf.stack([tf.range(0, obs), k_z - 1], axis=1) tau_sum = tf.gather_nd(z_cumsum, indices) tau_z = (tau_sum - 1) / k_z.astype(logits.dtype) res = tf.maximum(tf.cast(0, logits.dtype), z - tau_z[:, tf.newaxis]) # rotate axis res = tf.reshape(res, tshape) # res.set_shape(lshape) res = tf.shiftdim(res, axis + 1) return res
def dice_coeff_sorensen(pred, target, axis=None, eps=1e-8):
    """
    Sorensen-Dice coeff (V-Net formulation) https://arxiv.org/pdf/1606.04797.pdf
    ndim == 4, nhwc format
    (2 * sum(pred * target) + eps) / (sum(pred^2) + sum(target^2) + eps)

    :param pred: assume : range(0,1), 4dim
    :param target: 0 or 1 nhwc format 4dim
    :param axis: reduction axes. None -> every axis except the first and the
        last (h, w for nhwc). axis=0 is honored (not treated as None).
    :param eps: for prevent NaN; added to numerator and denominator
    :return: [batch x channel] with the default axis
    """
    assert pred.ndim == target.ndim
    # `is None` rather than `axis or ...`: axis=0 is a valid axis but falsy,
    # and would otherwise be silently replaced by the default axes.
    if axis is None:
        axis = list(range(1, pred.ndim - 1))
    # squared sums in the denominator are what distinguish this from dice_coeff
    ps = tf.square(pred).sum(axis=axis)
    ts = tf.square(target).sum(axis=axis)
    intersect = tf.sum(pred * target, axis=axis)
    return (2. * intersect + eps) / (ps + ts + eps)
def dice_coeff_digit(pred, target, axis=None, axis_img=(1, 2), threshold=0.2):
    """
    exact dice coeff for predictions that are already (near-)binary.

    pred is NOT binarized here; it is only used to decide whether a sample is
    "predicted empty" (its max over axis_img is <= threshold). Such samples
    score 1 if their target is also empty, else 0. All other samples get
    2 * intersect / (pred + target).

    :param pred: prediction tensor
    :param target: assumed to contain only zeros and ones
    :param axis: None -> return the mean over everything; otherwise the
        per-sample values are returned unreduced
    :param axis_img: image axes summed per sample, default (1, 2)
    :param threshold: max-value threshold for treating a pred as empty
    :return: scalar when axis is None, else [batch x channel]
    """
    pred_area = pred.sum(axis=axis_img)
    target_area = target.sum(axis=axis_img)
    overlap = tf.sum(pred * target, axis=axis_img)

    pred_empty = tf.less_equal(pred.max(axis=axis_img), threshold)
    empty_score = 1. - tf.any(target, axis=axis_img).to_float()
    dice = 2. * overlap / (pred_area + target_area)
    per_sample = tf.where(pred_empty, empty_score, dice)

    # NOTE(review): a non-None `axis` only disables the mean; its value is never
    # used for reduction. Confirm with callers whether per_sample.mean(axis)
    # was intended.
    if axis is not None:
        return per_sample
    return per_sample.mean()
def sampling_xy_3r(img, xys, outsize=None, oob=None):
    """
    differentiable image sampling (with bilinear interpolation) of a single image

    :param img: source image [H, W, C] (rank 3 is asserted)
    :param xys: sample coordinates, row 0 = x (scaled by W), row 1 = y (scaled by H).
        [2, H'*W'] when outsize is given; rank 3 ([2, H', W']) when outsize is None.
        NOTE(review): pixels are located at (xys + 0.5) * [W, H], which suggests
        xys is normalized to roughly [-0.5, 0.5] over the image -- confirm with callers.
    :param outsize: [H', W'] or None (then taken from xys' trailing dims)
    :param oob: fill value for out-of-bound samples, scalar or one value per
        channel; defaults to 0.
    :return: [H', W', C]
    """
    assert img.ndim == 3
    oobv = oob
    if oobv is None:
        # oobv = tf.zeros(shape=(img.dims[-1]), dtype=tf.float32)  # [0., 0., 0.]
        oobv = 0.
        # oobv = [0., 0., 0.]
    oobv = tf.convert_to_tensor(oobv)
    if outsize is None:
        # output size comes from xys [2, H', W']; flat2d presumably flattens to [2, H'*W']
        outsize = tf.shape(xys)[1:]
        xys = xys.flat2d()
    H, W, C = img.shapes
    WH = tf.stack([W, H]).to_float().reshape((2, 1))
    # XYf = (xys + 1.) * WH * 0.5
    # scale to continuous pixel coords: row 0 by W, row 1 by H
    XYf = (xys + 0.5) * WH  # * 0.5
    # upper integer neighbour along each axis
    XYS = tf.ceil(XYf)
    # prepare weights: 1-D linear weights, w00 for the lower neighbour (ceil - 1),
    # w11 for the upper neighbour (ceil)
    w00 = XYS - XYf  # [2, p]
    w11 = 1. - w00  # [2, p]
    # get near 4 pixels per pixel
    XYS = XYS.to_int32()  # [2, p]
    # todo check xy order
    XYs = XYS - 1
    Xs = tf.stack([XYs[0], XYS[0]])  # [left, right]
    Ys = tf.stack([XYs[1], XYS[1]])  # [top, bottom]
    # get mask of out of bound: compare clipped vs unclipped indices.
    # The clipped indices (Xi, Yi) are what is gathered below, so border samples
    # read valid pixels; the mask decides whether the oob fill replaces them.
    Xi = Xs.clip_by_value(0, W - 1)
    Yi = Ys.clip_by_value(0, H - 1)
    inb = tf.logical_and(Xi.equal(Xs), Yi.equal(Ys))  # [2, p]: (left, top) and (right, bottom) corners
    # a sample is in-bound if either of those two corners lies inside the image
    inb = tf.reduce_any(inb, axis=0, keepdims=True)  # [1, p]
    # inb = inb.expand_dims(2).to_float()  # [1, p]
    inb = inb.reshape((-1, 1)).to_float()  # [p, 1] 1 for channel
    # get 4 pixels [p, C]; gather indices are (row=y, col=x)
    p00 = getpixel(img, tf.stack([Yi[0], Xi[0]]).T)  # top-left
    p01 = getpixel(img, tf.stack([Yi[0], Xi[1]]).T)  # top-right
    p10 = getpixel(img, tf.stack([Yi[1], Xi[0]]).T)  # bottom-left
    p11 = getpixel(img, tf.stack([Yi[1], Xi[1]]).T)  # bottom-right
    # stacked nearest : [4, p, C]
    near4 = tf.stack([p00, p01, p10, p11], axis=0)
    # w4 : bilinear weights of the 4 near points, same order as near4 [4, p]
    w4 = tf.stack([
        w00[1] * w00[0],  # left top
        w00[1] * w11[0],  # right top
        w11[1] * w00[0],  # left bottom
        w11[1] * w11[0]])  # right bottom
    # weighted sum of 4 nearest pixels; [4, p, 1] broadcasts over channels
    w4 = w4.reshape((4, -1, 1))
    interpolated = tf.sum(w4 * near4.to_float(), axis=0)  # [p, C]
    # assign oob value: inb is 1 for in-bound samples, 0 otherwise
    # fill oob by broadcasting
    oobv = oobv.reshape((1, -1))  # [1, C] (or [1, 1] for a scalar fill)
    interpolated = interpolated * inb + oobv * (1. - inb)
    # reshape [p, C] => [H', W', C]
    output = interpolated.reshape((outsize[0], outsize[1], C))
    return output