Esempio n. 1
0
    def build_loss(self,
                   rpn_box_pred,
                   rpn_bbox_targets,
                   rpn_cls_score,
                   rpn_labels,
                   bbox_pred,
                   bbox_targets,
                   cls_score,
                   labels,
                   mask_list=None,
                   mask_gt_list=None):
        """Build the training losses and store them in ``self.loss_dict``.

        Writes the keys ``'rpn_cls_loss'``, ``'rpn_reg_loss'``,
        ``'fast_cls_loss'`` and ``'fast_reg_loss'``. It also writes
        ``'mask_loss'`` when ``cfgs.USE_SUPERVISED_MASK`` is set. Each stored
        value is already multiplied by its ``cfgs.*_LOSS_WEIGHT`` factor.
        ``self.loss_dict`` is assumed to be created elsewhere (e.g. in
        ``__init__``); TODO confirm.

        :param rpn_box_pred: RPN box regression output, passed to
            ``losses.smooth_l1_loss_rpn``.
        :param rpn_bbox_targets: RPN regression targets.
        :param rpn_cls_score: RPN objectness logits. They are gathered along
            axis 0 and reshaped to ``[-1, 2]``.
        :param rpn_labels: RPN anchor labels. Entries equal to ``-1`` are
            excluded from the classification loss.
        :param bbox_pred: Fast R-CNN box regression output.
        :param bbox_targets: Fast R-CNN regression targets.
        :param cls_score: Fast R-CNN class logits (``cfgs.CLASS_NUM + 1``
            classes).
        :param labels: Fast R-CNN class labels.
        :param mask_list: sequence of predicted mask logit tensors. Required
            when ``cfgs.USE_SUPERVISED_MASK`` is enabled.
        :param mask_gt_list: ground-truth mask label tensors, one per entry of
            ``mask_list``.
        :raises ValueError: if supervised masks are enabled but the mask
            lists are missing, empty, or of different lengths.
        """
        with tf.variable_scope('build_loss'):

            if cfgs.USE_SUPERVISED_MASK:
                with tf.variable_scope("supervised_mask_loss"):
                    # Raise a clear error instead of TypeError (len(None)),
                    # ZeroDivisionError (empty list) or IndexError (GT list
                    # shorter than the prediction list).
                    if not mask_list or mask_gt_list is None:
                        raise ValueError(
                            'cfgs.USE_SUPERVISED_MASK requires non-empty '
                            'mask_list and mask_gt_list')
                    if len(mask_list) != len(mask_gt_list):
                        raise ValueError(
                            'mask_list and mask_gt_list must have the same '
                            'length, got %d and %d'
                            % (len(mask_list), len(mask_gt_list)))

                    # Mask logits have 2 channels when cfgs.BINARY_MASK is
                    # set, otherwise CLASS_NUM + 1 channels.
                    last_dim = 2 if cfgs.BINARY_MASK else cfgs.CLASS_NUM + 1
                    mask_losses = 0
                    for a_mask, a_mask_gt in zip(mask_list, mask_gt_list):
                        a_mask = tf.reshape(a_mask, shape=[-1, last_dim])
                        a_mask_gt = tf.reshape(a_mask_gt, shape=[-1])
                        mask_losses += tf.reduce_mean(
                            tf.nn.sparse_softmax_cross_entropy_with_logits(
                                logits=a_mask, labels=a_mask_gt))
                    # Average over the entries of mask_list, then weight.
                    self.loss_dict['mask_loss'] = (
                        mask_losses * cfgs.SUPERVISED_MASK_LOSS_WEIGHT /
                        float(len(mask_list)))

            with tf.variable_scope('rpn_loss'):

                rpn_reg_loss = losses.smooth_l1_loss_rpn(
                    bbox_pred=rpn_box_pred,
                    bbox_targets=rpn_bbox_targets,
                    label=rpn_labels,
                    sigma=cfgs.RPN_SIGMA)
                # Only anchors whose label is not -1 contribute to the
                # classification loss.
                rpn_select = tf.reshape(tf.where(tf.not_equal(rpn_labels, -1)),
                                        [-1])
                rpn_cls_score = tf.reshape(
                    tf.gather(rpn_cls_score, rpn_select), [-1, 2])
                rpn_labels = tf.reshape(tf.gather(rpn_labels, rpn_select),
                                        [-1])
                rpn_cls_loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=rpn_cls_score, labels=rpn_labels))

                self.loss_dict[
                    'rpn_cls_loss'] = rpn_cls_loss * cfgs.RPN_CLASSIFICATION_LOSS_WEIGHT
                self.loss_dict[
                    'rpn_reg_loss'] = rpn_reg_loss * cfgs.RPN_LOCATION_LOSS_WEIGHT

            with tf.variable_scope('FastRCNN_loss'):
                if cfgs.FAST_RCNN_MINIBATCH_SIZE != -1:
                    reg_loss = losses.smooth_l1_loss_rcnn_r(
                        bbox_pred=bbox_pred,
                        bbox_targets=bbox_targets,
                        label=labels,
                        num_classes=cfgs.CLASS_NUM + 1,
                        sigma=cfgs.FASTRCNN_SIGMA)

                    # The RoIs were already sampled upstream, so every label
                    # is used without further filtering.
                    cls_loss = tf.reduce_mean(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            logits=cls_score,
                            labels=labels))
                else:
                    # FAST_RCNN_MINIBATCH_SIZE == -1 selects OHEM.
                    print(20 * "@@")
                    print("@@" + 10 * " " + "TRAIN WITH OHEM ...")
                    print(20 * "@@")
                    cls_loss, reg_loss = losses.sum_ohem_loss(
                        cls_score=cls_score,
                        label=labels,
                        bbox_targets=bbox_targets,
                        bbox_pred=bbox_pred,
                        num_ohem_samples=256,
                        num_classes=cfgs.CLASS_NUM + 1,
                        sigma=cfgs.FASTRCNN_SIGMA)

                self.loss_dict[
                    'fast_cls_loss'] = cls_loss * cfgs.FAST_RCNN_CLASSIFICATION_LOSS_WEIGHT
                self.loss_dict[
                    'fast_reg_loss'] = reg_loss * cfgs.FAST_RCNN_LOCATION_LOSS_WEIGHT
    def build_loss(self, rpn_box_pred, rpn_bbox_targets, rpn_cls_score,
                   rpn_labels, bbox_pred_h, bbox_targets_h, cls_score_h,
                   bbox_pred_r, bbox_targets_r, cls_score_r, labels):
        '''Build RPN losses and the losses of both Fast R-CNN heads.

        The two Fast R-CNN heads are a horizontal-box head (``_h``) and a
        rotated-box head (``_r``).

        :param rpn_box_pred: [-1, 4]
        :param rpn_bbox_targets: [-1, 4]
        :param rpn_cls_score: expected [-1, 2]. It is gathered along axis 0
            by anchor index and reshaped to [-1, 2].
        :param rpn_labels: [-1]. Entries equal to -1 are ignored by the RPN
            classification loss.
        :param bbox_pred_h: [-1, 4*(cls_num+1)]
        :param bbox_targets_h: [-1, 4*(cls_num+1)]
        :param cls_score_h: [-1, cls_num+1]
        :param bbox_pred_r: [-1, 5*(cls_num+1)]
        :param bbox_targets_r: [-1, 5*(cls_num+1)]
        :param cls_score_r: [-1, cls_num+1]
        :param labels: [-1], shared by both Fast R-CNN heads
        :return: dict with keys 'rpn_cls_loss', 'rpn_loc_loss',
            'fastrcnn_cls_loss_h', 'fastrcnn_loc_loss_h',
            'fastrcnn_cls_loss_r' and 'fastrcnn_loc_loss_r'. Each value is
            already multiplied by its cfgs.*_LOSS_WEIGHT factor.
        :raises NotImplementedError: if cfgs.FAST_RCNN_MINIBATCH_SIZE == -1
            (OHEM), which is not supported for the dual-head setup.
        '''
        with tf.variable_scope('build_loss'):
            with tf.variable_scope('rpn_loss'):

                rpn_bbox_loss = losses.smooth_l1_loss_rpn(
                    bbox_pred=rpn_box_pred,
                    bbox_targets=rpn_bbox_targets,
                    label=rpn_labels,
                    sigma=cfgs.RPN_SIGMA)
                # Only anchors whose label is not -1 contribute to the
                # classification loss; rpn_labels is flattened to [-1].
                rpn_select = tf.reshape(tf.where(tf.not_equal(rpn_labels, -1)),
                                        [-1])
                rpn_cls_score = tf.reshape(
                    tf.gather(rpn_cls_score, rpn_select), [-1, 2])
                rpn_labels = tf.reshape(tf.gather(rpn_labels, rpn_select),
                                        [-1])
                rpn_cls_loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=rpn_cls_score, labels=rpn_labels))

                rpn_cls_loss = rpn_cls_loss * cfgs.RPN_CLASSIFICATION_LOSS_WEIGHT
                rpn_bbox_loss = rpn_bbox_loss * cfgs.RPN_LOCATION_LOSS_WEIGHT

            with tf.variable_scope('FastRCNN_loss'):
                if cfgs.FAST_RCNN_MINIBATCH_SIZE == -1:
                    # The previous OHEM branch could never work:
                    # - it assigned one call result to both cls_loss and
                    #   bbox_loss;
                    # - it omitted bbox_pred;
                    # - it never defined the *_h / *_r losses used below,
                    #   so it always ended in UnboundLocalError;
                    # - the rotated head had no OHEM loss at all.
                    # Fail explicitly until OHEM is implemented for both heads.
                    raise NotImplementedError(
                        'OHEM (cfgs.FAST_RCNN_MINIBATCH_SIZE == -1) is not '
                        'supported for the horizontal + rotated Fast R-CNN '
                        'heads; set a positive FAST_RCNN_MINIBATCH_SIZE')

                # The RoIs were already sampled upstream, so every label is
                # used without further filtering.
                bbox_loss_h = losses.smooth_l1_loss_rcnn_h(
                    bbox_pred=bbox_pred_h,
                    bbox_targets=bbox_targets_h,
                    label=labels,
                    num_classes=cfgs.CLASS_NUM + 1,
                    sigma=cfgs.FASTRCNN_SIGMA)
                cls_loss_h = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=cls_score_h, labels=labels))

                bbox_loss_r = losses.smooth_l1_loss_rcnn_r(
                    bbox_pred=bbox_pred_r,
                    bbox_targets=bbox_targets_r,
                    label=labels,
                    num_classes=cfgs.CLASS_NUM + 1,
                    sigma=cfgs.FASTRCNN_SIGMA)
                cls_loss_r = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=cls_score_r, labels=labels))

                cls_loss_h = cls_loss_h * cfgs.FAST_RCNN_CLASSIFICATION_LOSS_WEIGHT
                bbox_loss_h = bbox_loss_h * cfgs.FAST_RCNN_LOCATION_LOSS_WEIGHT
                cls_loss_r = cls_loss_r * cfgs.FAST_RCNN_CLASSIFICATION_LOSS_WEIGHT
                bbox_loss_r = bbox_loss_r * cfgs.FAST_RCNN_LOCATION_LOSS_WEIGHT
            loss_dict = {
                'rpn_cls_loss': rpn_cls_loss,
                'rpn_loc_loss': rpn_bbox_loss,
                'fastrcnn_cls_loss_h': cls_loss_h,
                'fastrcnn_loc_loss_h': bbox_loss_h,
                'fastrcnn_cls_loss_r': cls_loss_r,
                'fastrcnn_loc_loss_r': bbox_loss_r,
            }
        return loss_dict