def loss(logit_fh, logit_line, gt_fh, gt_line, user_stroke, scope='loss'):
    """Build the face-heat-map loss and the stroke-masked line loss.

    Args:
        logit_fh: Predicted face heat map. Its dynamic shape is indexed on
            four dimensions to size the all-ones tensor used for the mask.
        logit_line: Predicted line map; multiplied by the mask before the
            line losses are computed.
        gt_fh: Ground-truth face heat map; passed through
            ``slice_tensor(gt_fh, logit_fh)`` first (presumably a crop to the
            prediction's extent -- TODO confirm against ``slice_tensor``).
        gt_line: Ground-truth line map, used as-is.
        user_stroke: Tensor subtracted from an all-ones tensor to form the
            weight mask (presumably 1 on user-drawn pixels -- TODO confirm).
        scope: Name scope under which the ops are created.

    Returns:
        Tuple ``(total_loss, fh_loss, real_fh_loss, line_loss,
        real_line_loss)`` where ``total_loss = fh_loss + line_loss``; the
        ``real_*`` values are the absolute-difference counterparts and do not
        enter ``total_loss``.

    NOTE(review): the ``tf.losses.*`` helpers also register their result in
    the graph's loss collection by default, so the ``real_*`` terms are
    registered too -- confirm nothing relies on ``tf.losses.get_total_loss``.
    NOTE(review): another ``def loss`` with a different signature appears in
    this file; if both live in one module the later definition wins.
    """
    with tf.name_scope(scope):
        # Face heat map: squared error drives training, absolute error is
        # returned alongside it.
        fh_target = slice_tensor(gt_fh, logit_fh)
        fh_loss = tf.losses.mean_squared_error(fh_target, logit_fh)
        real_fh_loss = tf.losses.absolute_difference(fh_target, logit_fh)

        # Mask = 1 - user_stroke, with the ones tensor sized from logit_fh.
        fh_dims = tf.shape(logit_fh)
        ones_shape = [fh_dims[axis] for axis in range(4)]
        stroke_mask = tf.ones(ones_shape, tf.float32) - user_stroke

        # Line regression: the prediction is masked and the same mask is
        # used as the per-element loss weight.
        # NOTE(review): gt_line is not passed through slice_tensor, unlike
        # gt_fh -- confirm its shape already matches logit_line.
        masked_line = logit_line * stroke_mask
        line_loss = tf.losses.mean_squared_error(
            gt_line, masked_line, weights=stroke_mask)
        real_line_loss = tf.losses.absolute_difference(
            gt_line, masked_line, weights=stroke_mask)

        total_loss = fh_loss + line_loss

    return total_loss, fh_loss, real_fh_loss, line_loss, real_line_loss
def loss(logit_fh, logit_curve, fh, curve, user_stroke, scope='loss'):
    """Build the face-heat-map loss and a mask-normalised curve loss.

    The curve term is computed per sample: squared (or absolute) error summed
    over axes 1-3, divided by the sum of ``1 - user_stroke`` for that sample,
    then averaged over the batch.

    Args:
        logit_fh: Predicted face heat map. Its dynamic shape is indexed on
            four dimensions to size the all-ones tensor used for the mask.
        logit_curve: Predicted curve map.
        fh: Ground-truth face heat map; passed through
            ``slice_tensor(fh, logit_fh)`` first (presumably a crop to the
            prediction's extent -- TODO confirm against ``slice_tensor``).
        curve: Ground-truth curve map; passed through
            ``slice_tensor(curve, logit_curve)`` and then used both as the
            regression target and as the foreground mask.
        user_stroke: Tensor subtracted from an all-ones tensor to form
            ``stroke_mask`` (presumably 1 on user-drawn pixels -- TODO
            confirm).
        scope: Name scope under which the ops are created.

    Returns:
        Tuple ``(total_loss, fh_loss, c_loss, real_fh_loss, real_c_loss)``
        where ``total_loss = fh_loss + c_loss``; the ``real_*`` values are the
        absolute-error counterparts and do not enter ``total_loss``.

    NOTE(review): the ``tf.losses.*`` helpers also register their result in
    the graph's loss collection by default, so ``real_fh_loss`` is registered
    too -- confirm nothing relies on ``tf.losses.get_total_loss``.
    NOTE(review): another ``def loss`` with a different signature and return
    order appears in this file; if both live in one module the later
    definition wins.
    """
    with tf.name_scope(scope):
        # Face heat map: squared error drives training, absolute error is
        # returned alongside it.
        gt_fh = slice_tensor(fh, logit_fh)
        fh_loss = tf.losses.mean_squared_error(gt_fh, logit_fh)
        real_fh_loss = tf.losses.absolute_difference(gt_fh, logit_fh)

        # Curve masks. stroke_mask = 1 - user_stroke; diff_mask is what is
        # left of stroke_mask after removing the curve pixels (background).
        curve_mask = slice_tensor(curve, logit_curve)
        fh_dims = tf.shape(logit_fh)
        full_one = tf.ones(
            [fh_dims[0], fh_dims[1], fh_dims[2], fh_dims[3]], tf.float32)
        stroke_mask = full_one - user_stroke
        diff_mask = stroke_mask - curve_mask

        # Foreground should reproduce curve_mask; background should be zero,
        # so its residual is the masked prediction itself.
        fg_residual = curve_mask - logit_curve * curve_mask
        bg_residual = logit_curve * diff_mask

        # Per-sample normaliser. A sample whose mask sums to exactly zero
        # would otherwise divide by zero and push NaN/Inf into the batch
        # mean; only that exact-zero case is replaced, every other value is
        # used unchanged.
        nb_stroke_pixels = tf.reduce_sum(stroke_mask, axis=[1, 2, 3])
        safe_nb_pixels = tf.where(
            tf.equal(nb_stroke_pixels, 0.0),
            tf.ones_like(nb_stroke_pixels),
            nb_stroke_pixels)

        # Squared-error curve loss (used for training).
        curve_fg_sum = tf.reduce_sum(tf.square(fg_residual), axis=[1, 2, 3])
        curve_bg_sum = tf.reduce_sum(tf.square(bg_residual), axis=[1, 2, 3])
        curve_loss = (curve_fg_sum + curve_bg_sum) / safe_nb_pixels
        c_loss = tf.reduce_mean(curve_loss)

        # Absolute-error curve loss (returned for reporting only).
        real_curve_fg_sum = tf.reduce_sum(tf.abs(fg_residual), axis=[1, 2, 3])
        real_curve_bg_sum = tf.reduce_sum(tf.abs(bg_residual), axis=[1, 2, 3])
        real_curve_loss = (
            (real_curve_fg_sum + real_curve_bg_sum) / safe_nb_pixels)
        real_c_loss = tf.reduce_mean(real_curve_loss)

        total_loss = fh_loss + c_loss

    return total_loss, fh_loss, c_loss, real_fh_loss, real_c_loss