Ejemplo n.º 1
0
def loss(logit_fh, logit_line, gt_fh, gt_line, user_stroke, scope='loss'):
    """Compute the face-heat-map and line losses.

    The face heat map target is first passed through ``slice_tensor`` together
    with the prediction. The line prediction is multiplied by the mask
    ``1 - user_stroke`` (the all-ones tensor is built with the dynamic shape of
    ``logit_fh``) and compared to ``gt_line`` using that same mask as per-pixel
    weights.

    Args:
        logit_fh: Predicted face heat map (indexed as 4-D below).
        logit_line: Predicted line map.
        gt_fh: Face heat map target, before ``slice_tensor``.
        gt_line: Line target.
        user_stroke: Stroke image; ``1 - user_stroke`` is the line-loss mask.
        scope: Name scope for the created ops.

    Returns:
        ``(total_loss, fh_loss, real_fh_loss, line_loss, real_line_loss)``:
        the plain losses are squared errors, the ``real_*`` losses are
        absolute errors, and ``total_loss = fh_loss + line_loss``.
    """
    error_metrics = (tf.losses.mean_squared_error,
                     tf.losses.absolute_difference)

    with tf.name_scope(scope):
        # Face heat map: squared error first, then absolute error.
        # NOTE(review): slice_tensor is defined elsewhere; presumably it crops
        # the target to the prediction's size — confirm.
        target_fh = slice_tensor(gt_fh, logit_fh)
        fh_l2, fh_l1 = (metric(target_fh, logit_fh)
                        for metric in error_metrics)

        # Mask of pixels that take part in the line loss, shaped like the
        # 4-D dynamic shape of logit_fh.
        dims = tf.shape(logit_fh)
        keep = tf.ones([dims[axis] for axis in range(4)], tf.float32) - user_stroke

        # Masked-out pixels are zeroed in the prediction AND given zero weight.
        masked_line = logit_line * keep
        line_l2, line_l1 = (metric(gt_line, masked_line, weights=keep)
                            for metric in error_metrics)

        return fh_l2 + line_l2, fh_l2, fh_l1, line_l2, line_l1
Ejemplo n.º 2
0
def loss(logit_fh, logit_curve, fh, curve, user_stroke, scope='loss'):
    """Build the face-heat-map loss and the masked curve loss.

    Args:
        logit_fh: Predicted face heat map. Its dynamic shape is used for the
            all-ones tensor from which the stroke mask is derived.
        logit_curve: Predicted curve map.
        fh: Face heat map target, passed through ``slice_tensor`` with
            ``logit_fh`` before comparison.
        curve: Curve target, passed through ``slice_tensor`` with
            ``logit_curve``; the result is used both as the foreground mask
            and as the foreground target.
        user_stroke: Stroke image; ``1 - user_stroke`` selects the pixels
            that take part in the curve loss.
        scope: Name scope for the created ops.

    Returns:
        Tuple ``(total_loss, fh_loss, c_loss, real_fh_loss, real_c_loss)``.
        ``fh_loss`` and ``c_loss`` are squared-error losses, the ``real_*``
        values are the matching absolute-error losses, and
        ``total_loss = fh_loss + c_loss``. A sample whose stroke mask sums to
        exactly zero contributes 0 to the curve losses instead of NaN/Inf.
    """
    with tf.name_scope(scope):
        # Face heat map loss (squared and absolute error).
        # NOTE(review): slice_tensor is defined elsewhere; presumably it crops
        # the target to the prediction's size — confirm.
        gt_fh = slice_tensor(fh, logit_fh)
        fh_loss = tf.losses.mean_squared_error(gt_fh, logit_fh)
        real_fh_loss = tf.losses.absolute_difference(gt_fh, logit_fh)

        # Curve loss masks.
        curve_mask = slice_tensor(curve, logit_curve)
        full_one = tf.ones(tf.shape(logit_fh), tf.float32)
        stroke_mask = full_one - user_stroke
        # NOTE(review): diff_mask is negative wherever curve_mask exceeds
        # stroke_mask; confirm curves never fall outside the stroke mask.
        diff_mask = stroke_mask - curve_mask

        # Curve pixels are compared against curve_mask; the remaining masked
        # pixels are compared against zero, so their error is the masked
        # prediction itself.
        fg_err = curve_mask - logit_curve * curve_mask
        bg_err = logit_curve * diff_mask

        axes = [1, 2, 3]  # every axis except the batch axis
        # Per-sample normaliser: sum of stroke_mask over axes 1-3.
        nb_stroke_pixels = tf.reduce_sum(stroke_mask, axis=axes)

        c_loss = _safe_batch_mean(
            tf.reduce_sum(tf.square(fg_err), axis=axes) +
            tf.reduce_sum(tf.square(bg_err), axis=axes),
            nb_stroke_pixels)
        real_c_loss = _safe_batch_mean(
            tf.reduce_sum(tf.abs(fg_err), axis=axes) +
            tf.reduce_sum(tf.abs(bg_err), axis=axes),
            nb_stroke_pixels)

        total_loss = fh_loss + c_loss

        return total_loss, fh_loss, c_loss, real_fh_loss, real_c_loss


def _safe_batch_mean(per_sample_sum, per_sample_count):
    """Return the batch mean of ``per_sample_sum / per_sample_count``.

    Samples whose count is exactly zero contribute 0 instead of NaN/Inf.
    Their denominator is swapped for 1 *before* dividing so the branch that
    ``tf.where`` discards stays finite and cannot leak NaN gradients.
    """
    nonzero = tf.not_equal(per_sample_count, 0.0)
    safe_count = tf.where(nonzero, per_sample_count,
                          tf.ones_like(per_sample_count))
    ratio = tf.where(nonzero, per_sample_sum / safe_count,
                     tf.zeros_like(per_sample_sum))
    return tf.reduce_mean(ratio)