def metric_fn(predictions, labels, params=None):
        """Build the evaluation metrics for the thresholded binary prediction.

        Thresholds ``predictions['prob']`` at ``score_threshold`` (taken from
        the enclosing scope), registers a 'prediction' image summary and an
        'eval/prc' PR-curve summary, and returns the mean-IoU and
        positive-IoU metrics.

        Args:
            predictions: mapping that provides the probability tensor under
                the key 'prob'.
            labels: ground-truth tensor; it is flattened to 1-D before use.
            params: optional mapping that provides the input images under the
                key 'inputs'. When it is None the image summary is skipped
                and only the metrics and the PR curve are built.

        Returns:
            dict with the keys 'eval/miou' and 'eval/piou', holding whatever
            ``tf.metrics.mean_iou`` and ``positive_iou`` return.
        """
        pred_prob = predictions['prob']
        flat_pred_prob = tf.reshape(pred_prob, [-1])
        pred_class = tf.to_float(tf.greater_equal(pred_prob, score_threshold))

        # The visualisation is the only consumer of `params`. Guard it so that
        # calling with the default `params=None` no longer dies with a
        # TypeError on `params['inputs']`.
        if params is not None:
            images = denormalize(params['inputs'])
            # Scale the 0/1 class map to 0/255 so it is visible as an image.
            pred_images = pred_class * 255
            vis_preds = tf.concat(
                [
                    tf.cast(images, pred_class.dtype),
                    # Repeat the mask 3x along the last axis; presumably this
                    # matches a 3-channel input image -- TODO confirm.
                    tf.tile(pred_images, multiples=[1, 1, 1, 3]),
                ],
                # Presumably the width axis, i.e. input and prediction end up
                # side by side -- TODO confirm the layout.
                axis=2)
            vis_preds = tf.cast(vis_preds, tf.uint8)
            tf.summary.image('prediction', vis_preds, max_outputs=10)

        labels = tf.reshape(labels, [-1])
        pred_class = tf.reshape(pred_class, [-1])
        miou = tf.metrics.mean_iou(labels, pred_class, num_classes=2)
        # `pred_class` is already flat at this point, so it is passed as is.
        p_iou = positive_iou(labels, pred_class, num_classes=2)

        pr_curve('eval/prc',
                 tf.cast(labels, tf.bool),
                 flat_pred_prob,
                 num_thresholds=201)
        return {'eval/miou': miou, 'eval/piou': p_iou}
    def loss_fn(predictions, labels, params=None):
        """Build the training loss and attach training-time summaries.

        Computes a weighted sigmoid cross-entropy between
        ``predictions['logit']`` and ``labels``. As side effects it registers
        scalar summaries for the loss and for the mean / positive / negative
        IoU, named tensors 'train_miou', 'train_piou' and 'train_niou', and a
        'train/prc' PR-curve summary. ``num_classes`` and ``score_threshold``
        are taken from the enclosing scope.

        Args:
            predictions: mapping that provides the logit tensor under the key
                'logit'.
            labels: ground-truth tensor; it is flattened to 1-D before use.
            params: not used by this function.

        Returns:
            The loss tensor produced by ``tf.losses.sigmoid_cross_entropy``.
        """
        flat_labels = tf.reshape(labels, [-1])
        logits_2d = tf.reshape(predictions['logit'], [-1, num_classes])
        # One column per class: one-hot targets for the multi-class case,
        # otherwise the labels themselves as a single column.
        targets = (tf.one_hot(flat_labels, num_classes) if num_classes > 1
                   else tf.reshape(flat_labels, [-1, 1]))

        sample_weights = balance_positive_negative_weight(
            flat_labels, positive_weight=1., negative_weight=1.)
        # The trailing axis lets the per-sample weights broadcast over the
        # class columns of the logits.
        ce_loss = tf.losses.sigmoid_cross_entropy(
            targets, logits_2d, weights=sample_weights[:, None])

        # NOTE(review): this has N * num_classes elements while the labels
        # have N, so the metrics below presumably only line up when
        # num_classes == 1 -- confirm.
        scores = tf.sigmoid(tf.reshape(logits_2d, [-1]))
        tf.summary.scalar('loss/cross_entropy_loss', ce_loss)

        hard_pred = tf.to_float(tf.greater_equal(scores, score_threshold))

        # (tensor name, summary tag, metric builder, value extractor), in the
        # order mean IoU -> positive IoU -> negative IoU.
        train_metrics = (
            ('train_miou', 'train/miou',
             tf.metrics.mean_iou, compute_mean_iou),
            ('train_piou', 'train/piou',
             positive_iou, compute_positive_iou),
            ('train_niou', 'train/niou',
             negative_iou, compute_negative_iou),
        )
        for tensor_name, tag, build_metric, extract_value in train_metrics:
            # Only element [1] of the builder's result is consumed (for
            # tf.metrics.mean_iou that is the update op).
            metric_out = build_metric(flat_labels,
                                      tf.reshape(hard_pred, [-1]),
                                      num_classes=2)
            metric_value = extract_value(None, metric_out[1])
            # Give the value a fixed graph name in addition to the summary.
            tf.identity(metric_value, tensor_name)
            tf.summary.scalar(tag, metric_value)

        pr_curve('train/prc',
                 tf.cast(flat_labels, tf.bool),
                 scores,
                 num_thresholds=201)

        return ce_loss