def build_loss(self, rpn_box_pred, rpn_bbox_targets, rpn_cls_score,
               rpn_labels, bbox_pred, bbox_targets, cls_score, labels):
        """Assemble the RPN and Fast-RCNN losses and return them as a dict.

        :param rpn_box_pred: [-1, 4]
        :param rpn_bbox_targets: [-1, 4]
        :param rpn_cls_score: [-1]
        :param rpn_labels: [-1]
        :param bbox_pred: [-1, 4*(cls_num+1)]
        :param bbox_targets: [-1, 4*(cls_num+1)]
        :param cls_score: [-1, cls_num+1]
        :param labels: [-1]
        :return: dict with keys 'rpn_cls_loss', 'rpn_loc_loss',
                 'fastrcnn_cls_loss', 'fastrcnn_loc_loss'
        """
        with tf.variable_scope('build_loss'):
            with tf.variable_scope('rpn_loss'):
                # RPN localization loss (smooth L1 over the anchor deltas).
                rpn_bbox_loss = losses.smooth_l1_loss_rpn(
                    bbox_pred=rpn_box_pred,
                    bbox_targets=rpn_bbox_targets,
                    label=rpn_labels,
                    sigma=cfgs.RPN_SIGMA)

                # Anchors labelled -1 are "don't care": keep only the
                # sampled foreground/background anchors for classification.
                valid_inds = tf.reshape(
                    tf.where(tf.not_equal(rpn_labels, -1)), [-1])
                kept_scores = tf.reshape(
                    tf.gather(rpn_cls_score, valid_inds), [-1, 2])
                kept_labels = tf.reshape(
                    tf.gather(rpn_labels, valid_inds), [-1])
                rpn_cls_loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=kept_scores, labels=kept_labels))

                # Apply the configured weights to the two RPN terms.
                rpn_cls_loss *= cfgs.RPN_CLASSIFICATION_LOSS_WEIGHT
                rpn_bbox_loss *= cfgs.RPN_LOCATION_LOSS_WEIGHT

            with tf.variable_scope('FastRCNN_loss'):
                if cfgs.FAST_RCNN_MINIBATCH_SIZE != -1:
                    # Proposals were already subsampled upstream, so a
                    # plain mean over the minibatch is sufficient.
                    bbox_loss = losses.smooth_l1_loss_rcnn(
                        bbox_pred=bbox_pred,
                        bbox_targets=bbox_targets,
                        label=labels,
                        num_classes=cfgs.CLASS_NUM + 1,
                        sigma=cfgs.FASTRCNN_SIGMA)
                    cls_loss = tf.reduce_mean(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            logits=cls_score,
                            labels=labels))
                else:
                    # Online Hard Example Mining branch.
                    print(20 * "@@")
                    print("@@" + 10 * " " + "TRAIN WITH OHEM ...")
                    print(20 * "@@")
                    # NOTE(review): both names alias the SAME tensor here,
                    # and each is then scaled by its own weight below —
                    # confirm this double-weighting is intended (sibling
                    # implementations unpack a (cls, bbox) tuple instead).
                    cls_loss = bbox_loss = losses.sum_ohem_loss(
                        cls_score=cls_score,
                        label=labels,
                        bbox_targets=bbox_targets,
                        nr_ohem_sampling=128,
                        nr_classes=cfgs.CLASS_NUM + 1)

                cls_loss *= cfgs.FAST_RCNN_CLASSIFICATION_LOSS_WEIGHT
                bbox_loss *= cfgs.FAST_RCNN_LOCATION_LOSS_WEIGHT

            loss_dict = {
                'rpn_cls_loss': rpn_cls_loss,
                'rpn_loc_loss': rpn_bbox_loss,
                'fastrcnn_cls_loss': cls_loss,
                'fastrcnn_loc_loss': bbox_loss
            }
        return loss_dict
Example #2
0
    def build_loss(self, rpn_box_pred, rpn_bbox_targets, rpn_cls_score,
                   rpn_labels, bbox_pred, bbox_targets, cls_score, labels):
        """
        Compute the weighted RPN and Fast-RCNN losses.

        :param rpn_box_pred: [-1, 4]
        :param rpn_bbox_targets: [-1, 4]
        :param rpn_cls_score: [-1]
        :param rpn_labels: [-1]
        :param bbox_pred: [-1, 4*(cls_num+1)]
        :param bbox_targets: [-1, 4*(cls_num+1)]
        :param cls_score: [-1, cls_num+1]
        :param labels: [-1]
        :return: dict of the four weighted loss tensors
        """
        with tf.variable_scope('build_loss') as sc:
            with tf.variable_scope('rpn_loss'):
                # RPN localization (smooth-L1) loss.
                rpn_bbox_loss = losses.smooth_l1_loss_rpn(
                    bbox_pred=rpn_box_pred,
                    bbox_targets=rpn_bbox_targets,
                    labels=rpn_labels,
                    sigma=cfgs.RPN_SIGMA)

                # Drop ignored anchors (label == -1); classify only the
                # sampled foreground/background anchors.
                keep = tf.reshape(tf.where(tf.not_equal(rpn_labels, -1)),
                                  shape=[-1])
                picked_scores = tf.reshape(tf.gather(rpn_cls_score, keep),
                                           shape=[-1, 2])
                picked_labels = tf.reshape(tf.gather(rpn_labels, keep),
                                           shape=[-1])

                rpn_cls_loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=picked_scores, labels=picked_labels))

                # Weighted RPN classification and localization losses.
                rpn_cls_loss = cfgs.RPN_CLASSIFICATION_LOSS_WEIGHT * rpn_cls_loss
                rpn_bbox_loss = cfgs.RPN_LOCATION_LOSS_WEIGHT * rpn_bbox_loss

            with tf.variable_scope('FastRCNN_loss'):
                if cfgs.FAST_RCNN_MINIBATCH_SIZE != -1:
                    # Proposals were sampled earlier, so a plain mean is fine.
                    bbox_loss = losses.smooth_l1_loss_rcnn(
                        bbox_pred=bbox_pred,
                        bbox_targets=bbox_targets,
                        label=labels,
                        num_classes=cfgs.CLASS_NUM + 1,
                        sigma=cfgs.FASTRCNN_SIGMA)
                    cls_loss = tf.reduce_mean(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            logits=cls_score,
                            labels=labels))
                else:
                    # Online Hard Example Mining branch.
                    print("TRAIN WITH OHEM ...")
                    cls_loss, bbox_loss = losses.sum_ohem_loss(
                        cls_score=cls_score,
                        labels=labels,
                        bbox_targets=bbox_targets,
                        bbox_pred=bbox_pred,
                        num_ohem_samples=256,
                        num_classes=cfgs.CLASS_NUM + 1)

                # Weighted Fast-RCNN classification and localization losses.
                cls_loss = cfgs.FAST_RCNN_CLASSIFICATION_LOSS_WEIGHT * cls_loss
                bbox_loss = cfgs.FAST_RCNN_LOCATION_LOSS_WEIGHT * bbox_loss

            loss_dict = dict(rpn_cls_loss=rpn_cls_loss,
                             rpn_loc_loss=rpn_bbox_loss,
                             fastrcnn_cls_loss=cls_loss,
                             fastrcnn_loc_loss=bbox_loss)
        return loss_dict
Example #3
0
    def build_loss(self, rpn_box_pred, rpn_bbox_targets, rpn_cls_score, rpn_labels,
                   bbox_pred, bbox_targets, cls_score, labels, mask_list=None, mask_gt_list=None):
        """Populate self.loss_dict with the (optional) supervised-mask loss
        plus the weighted RPN and Fast-RCNN classification/regression losses.
        Does not return anything; results are stored in self.loss_dict.
        """
        with tf.variable_scope('build_loss'):

            if cfgs.USE_SUPERVISED_MASK:
                with tf.variable_scope("supervised_mask_loss"):
                    total_mask_loss = 0
                    print(mask_list)
                    print(mask_gt_list)
                    # Per-pixel cross-entropy averaged over every mask level.
                    num_mask_cls = 2 if cfgs.BINARY_MASK else cfgs.CLASS_NUM + 1
                    for idx, a_mask in enumerate(mask_list):
                        a_mask_gt = mask_gt_list[idx]
                        flat_logits = tf.reshape(a_mask, shape=[-1, num_mask_cls])
                        flat_gt = tf.reshape(a_mask_gt, shape=[-1])
                        total_mask_loss += tf.reduce_mean(
                            tf.nn.sparse_softmax_cross_entropy_with_logits(
                                logits=flat_logits, labels=flat_gt))
                    self.loss_dict['mask_loss'] = (
                        total_mask_loss * cfgs.SUPERVISED_MASK_LOSS_WEIGHT
                        / float(len(mask_list)))

            with tf.variable_scope('rpn_loss'):
                # RPN localization (smooth-L1) loss.
                rpn_reg_loss = losses.smooth_l1_loss_rpn(bbox_pred=rpn_box_pred,
                                                         bbox_targets=rpn_bbox_targets,
                                                         label=rpn_labels,
                                                         sigma=cfgs.RPN_SIGMA)
                # Keep only sampled anchors (label != -1) for classification.
                keep = tf.reshape(tf.where(tf.not_equal(rpn_labels, -1)), [-1])
                kept_scores = tf.reshape(tf.gather(rpn_cls_score, keep), [-1, 2])
                kept_labels = tf.reshape(tf.gather(rpn_labels, keep), [-1])
                rpn_cls_loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(logits=kept_scores,
                                                                   labels=kept_labels))

                self.loss_dict['rpn_cls_loss'] = rpn_cls_loss * cfgs.RPN_CLASSIFICATION_LOSS_WEIGHT
                self.loss_dict['rpn_reg_loss'] = rpn_reg_loss * cfgs.RPN_LOCATION_LOSS_WEIGHT

            with tf.variable_scope('FastRCNN_loss'):
                if cfgs.FAST_RCNN_MINIBATCH_SIZE != -1:
                    # Proposals already sampled upstream: plain mean suffices.
                    reg_loss = losses.smooth_l1_loss_rcnn(bbox_pred=bbox_pred,
                                                          bbox_targets=bbox_targets,
                                                          label=labels,
                                                          num_classes=cfgs.CLASS_NUM + 1,
                                                          sigma=cfgs.FASTRCNN_SIGMA)
                    cls_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=cls_score,
                        labels=labels))
                else:
                    # Online Hard Example Mining branch.
                    print(20 * "@@")
                    print("@@" + 10 * " " + "TRAIN WITH OHEM ...")
                    print(20 * "@@")
                    cls_loss, reg_loss = losses.sum_ohem_loss(cls_score=cls_score,
                                                              label=labels,
                                                              bbox_targets=bbox_targets,
                                                              bbox_pred=bbox_pred,
                                                              num_ohem_samples=256,
                                                              num_classes=cfgs.CLASS_NUM + 1,
                                                              sigma=cfgs.FASTRCNN_SIGMA)

                self.loss_dict['fast_cls_loss'] = cls_loss * cfgs.FAST_RCNN_CLASSIFICATION_LOSS_WEIGHT
                self.loss_dict['fast_reg_loss'] = reg_loss * cfgs.FAST_RCNN_LOCATION_LOSS_WEIGHT
Example #4
0
    def build_whole_detection_network(self, input_img_batch, gtboxes_batch):
        """Build the SSH-style detection graph with three detection modules
        (M1 on the fused stride-8 map, M2 on stride-16, M3 on pooled
        stride-32 features).

        :param input_img_batch: input image batch tensor
        :param gtboxes_batch: ground-truth boxes; reshaped to [M, 5]
                              (used only when self.is_training)
        :return: at inference time, a result_dict of per-module final
                 boxes/scores/categories; at training time, a tuple
                 (result_dict, losses_dict).
        """
        if self.is_training:
            # ensure shape is [M, 5]
            gtboxes_batch = tf.reshape(gtboxes_batch, [-1, 5])
            gtboxes_batch = tf.cast(gtboxes_batch, tf.float32)

        img_shape = tf.shape(input_img_batch)

        # 1. build base network: feature maps at stride 8 and stride 16
        feature_stride8, feature_stride16 = self.build_base_network(
            input_img_batch)

        # 2. build the SSH detection modules
        with tf.variable_scope('build_ssh',
                               regularizer=slim.l2_regularizer(
                                   cfgs.WEIGHT_DECAY)):

            # M3 branch: stride-32 features from max-pooling the stride-16 map.
            ssh_max_pool = slim.max_pool2d(inputs=feature_stride16,
                                           kernel_size=[2, 2],
                                           scope='ssh_max_pool')

            cls_score_m3, box_pred_m3 = self.detection_module(
                ssh_max_pool, self.m3_num_anchors_per_location,
                'detection_module_m3')
            box_pred_m3 = tf.reshape(box_pred_m3,
                                     [-1, 4 * (cfgs.CLASS_NUM + 1)])
            cls_score_m3 = tf.reshape(cls_score_m3, [-1, (cfgs.CLASS_NUM + 1)])
            cls_prob_m3 = slim.softmax(cls_score_m3, scope='cls_prob_m3')

            # M2 branch: directly on the stride-16 feature map.
            cls_score_m2, box_pred_m2 = self.detection_module(
                feature_stride16, self.m2_num_anchors_per_location,
                'detection_module_m2')
            box_pred_m2 = tf.reshape(box_pred_m2,
                                     [-1, 4 * (cfgs.CLASS_NUM + 1)])
            cls_score_m2 = tf.reshape(cls_score_m2, [-1, (cfgs.CLASS_NUM + 1)])
            cls_prob_m2 = slim.softmax(cls_score_m2, scope='cls_prob_m2')

            # M1 branch: fuse upsampled stride-16 features into the stride-8
            # map (1x1 convs to reduce channels, bilinear upsample, eltwise
            # sum, 3x3 conv), then run the detection module on the result.
            channels_16 = feature_stride16.get_shape().as_list()[-1]
            channels_8 = feature_stride8.get_shape().as_list()[-1]
            feature8_shape = tf.shape(feature_stride8)
            conv1x1_1 = slim.conv2d(inputs=feature_stride16,
                                    num_outputs=channels_16 // 4,
                                    kernel_size=[1, 1],
                                    trainable=self.is_training,
                                    weights_initializer=cfgs.INITIALIZER,
                                    activation_fn=tf.nn.relu,
                                    scope='conv1x1_1')
            upsampling = tf.image.resize_bilinear(
                conv1x1_1, [feature8_shape[1], feature8_shape[2]],
                name='upsampling')

            conv1x1_2 = slim.conv2d(inputs=feature_stride8,
                                    num_outputs=channels_8 // 2,
                                    kernel_size=[1, 1],
                                    trainable=self.is_training,
                                    weights_initializer=cfgs.INITIALIZER,
                                    activation_fn=tf.nn.relu,
                                    scope='conv1x1_2')

            eltwise_sum = upsampling + conv1x1_2

            conv3x3 = slim.conv2d(inputs=eltwise_sum,
                                  num_outputs=channels_8 // 2,
                                  kernel_size=[3, 3],
                                  trainable=self.is_training,
                                  weights_initializer=cfgs.INITIALIZER,
                                  activation_fn=tf.nn.relu,
                                  scope='conv3x3')

            cls_score_m1, box_pred_m1 = self.detection_module(
                conv3x3, self.m1_num_anchors_per_location,
                'detection_module_m1')
            box_pred_m1 = tf.reshape(box_pred_m1,
                                     [-1, 4 * (cfgs.CLASS_NUM + 1)])
            cls_score_m1 = tf.reshape(cls_score_m1, [-1, (cfgs.CLASS_NUM + 1)])
            cls_prob_m1 = slim.softmax(cls_score_m1, scope='cls_prob_m1')

        # 3. generate anchors for each module's feature map
        featuremap_height_m1, featuremap_width_m1 = tf.shape(feature_stride8)[1], \
                                                    tf.shape(feature_stride8)[2]
        featuremap_height_m1 = tf.cast(featuremap_height_m1, tf.float32)
        featuremap_width_m1 = tf.cast(featuremap_width_m1, tf.float32)

        # NOTE(review): all three modules use BASE_ANCHOR_SIZE_LIST[0];
        # scales differ per module, but confirm the shared base size is
        # intentional rather than an index oversight.
        anchors_m1 = anchor_utils.make_anchors(
            base_anchor_size=cfgs.BASE_ANCHOR_SIZE_LIST[0],
            anchor_scales=cfgs.M1_ANCHOR_SCALES,
            anchor_ratios=cfgs.ANCHOR_RATIOS,
            featuremap_height=featuremap_height_m1,
            featuremap_width=featuremap_width_m1,
            stride=[cfgs.ANCHOR_STRIDE[0]],
            name="make_anchors_for_m1")

        featuremap_height_m2, featuremap_width_m2 = tf.shape(feature_stride16)[1], \
                                                    tf.shape(feature_stride16)[2]
        featuremap_height_m2 = tf.cast(featuremap_height_m2, tf.float32)
        # BUG FIX: previously this cast featuremap_width_m1 (copy-paste
        # error), which sized the M2 anchor grid with the M1 map's width.
        featuremap_width_m2 = tf.cast(featuremap_width_m2, tf.float32)

        anchors_m2 = anchor_utils.make_anchors(
            base_anchor_size=cfgs.BASE_ANCHOR_SIZE_LIST[0],
            anchor_scales=cfgs.M2_ANCHOR_SCALES,
            anchor_ratios=cfgs.ANCHOR_RATIOS,
            featuremap_height=featuremap_height_m2,
            featuremap_width=featuremap_width_m2,
            stride=[cfgs.ANCHOR_STRIDE[1]],
            name="make_anchors_for_m2")

        featuremap_height_m3, featuremap_width_m3 = tf.shape(ssh_max_pool)[1], \
                                                    tf.shape(ssh_max_pool)[2]
        featuremap_height_m3 = tf.cast(featuremap_height_m3, tf.float32)
        featuremap_width_m3 = tf.cast(featuremap_width_m3, tf.float32)

        anchors_m3 = anchor_utils.make_anchors(
            base_anchor_size=cfgs.BASE_ANCHOR_SIZE_LIST[0],
            anchor_scales=cfgs.M3_ANCHOR_SCALES,
            anchor_ratios=cfgs.ANCHOR_RATIOS,
            featuremap_height=featuremap_height_m3,
            featuremap_width=featuremap_width_m3,
            stride=[cfgs.ANCHOR_STRIDE[2]],
            name="make_anchors_for_m3")

        # refer to paper: Seeing Small Faces from Robust Anchor's Perspective
        if cfgs.EXTRA_SHIFTED_ANCHOR:
            shift_anchors_m1 = anchor_utils.shift_anchor(
                anchors_m1, cfgs.ANCHOR_STRIDE[0])
            shift_anchors_m2 = anchor_utils.shift_anchor(
                anchors_m2, cfgs.ANCHOR_STRIDE[1])
            shift_anchors_m3 = anchor_utils.shift_anchor(
                anchors_m3, cfgs.ANCHOR_STRIDE[2])
        else:
            shift_anchors_m1, shift_anchors_m2, shift_anchors_m3 = [], [], []

        if cfgs.FACE_SHIFT_JITTER:
            jitter_anchors_m1 = anchor_utils.shift_jitter(
                anchors_m1, cfgs.ANCHOR_STRIDE[0])
            jitter_anchors_m2 = anchor_utils.shift_jitter(
                anchors_m2, cfgs.ANCHOR_STRIDE[1])
            jitter_anchors_m3 = anchor_utils.shift_jitter(
                anchors_m3, cfgs.ANCHOR_STRIDE[2])
        else:
            jitter_anchors_m1, jitter_anchors_m2, jitter_anchors_m3 = [], [], []

        # Interleave base / shifted / jittered anchors per location.
        anchors_m1 = [anchors_m1] + shift_anchors_m1 + jitter_anchors_m1
        anchors_m1 = tf.reshape(tf.stack(anchors_m1, axis=1), [-1, 4])

        anchors_m2 = [anchors_m2] + shift_anchors_m2 + jitter_anchors_m2
        anchors_m2 = tf.reshape(tf.stack(anchors_m2, axis=1), [-1, 4])

        anchors_m3 = [anchors_m3] + shift_anchors_m3 + jitter_anchors_m3
        anchors_m3 = tf.reshape(tf.stack(anchors_m3, axis=1), [-1, 4])

        # 4. (training only) sample minibatches and build regression targets
        if self.is_training:
            with tf.variable_scope('sample_ssh_minibatch_m1'):
                rois_m1, labels_m1, bbox_targets_m1, keep_inds_m1 = \
                    tf.py_func(proposal_target_layer,
                               [anchors_m1, gtboxes_batch, 'M1'],
                               [tf.float32, tf.float32, tf.float32, tf.int32])
                rois_m1 = tf.reshape(rois_m1, [-1, 4])
                labels_m1 = tf.to_int32(labels_m1)
                labels_m1 = tf.reshape(labels_m1, [-1])
                bbox_targets_m1 = tf.reshape(bbox_targets_m1,
                                             [-1, 4 * (cfgs.CLASS_NUM + 1)])
                self.add_roi_batch_img_smry(input_img_batch, rois_m1,
                                            labels_m1, 'm1')

            with tf.variable_scope('sample_ssh_minibatch_m2'):
                rois_m2, labels_m2, bbox_targets_m2, keep_inds_m2 = \
                    tf.py_func(proposal_target_layer,
                               [anchors_m2, gtboxes_batch, 'M2'],
                               [tf.float32, tf.float32, tf.float32, tf.int32])
                rois_m2 = tf.reshape(rois_m2, [-1, 4])
                labels_m2 = tf.to_int32(labels_m2)
                labels_m2 = tf.reshape(labels_m2, [-1])
                bbox_targets_m2 = tf.reshape(bbox_targets_m2,
                                             [-1, 4 * (cfgs.CLASS_NUM + 1)])
                self.add_roi_batch_img_smry(input_img_batch, rois_m2,
                                            labels_m2, 'm2')

            with tf.variable_scope('sample_ssh_minibatch_m3'):
                rois_m3, labels_m3, bbox_targets_m3, keep_inds_m3 = \
                    tf.py_func(proposal_target_layer,
                               [anchors_m3, gtboxes_batch, 'M3'],
                               [tf.float32, tf.float32, tf.float32, tf.int32])
                rois_m3 = tf.reshape(rois_m3, [-1, 4])
                labels_m3 = tf.to_int32(labels_m3)
                labels_m3 = tf.reshape(labels_m3, [-1])
                bbox_targets_m3 = tf.reshape(bbox_targets_m3,
                                             [-1, 4 * (cfgs.CLASS_NUM + 1)])
                self.add_roi_batch_img_smry(input_img_batch, rois_m3,
                                            labels_m3, 'm3')

        # 5a. inference path: decode + NMS per module and return results
        if not self.is_training:
            with tf.variable_scope('postprocess_ssh_m1'):
                final_bbox_m1, final_scores_m1, final_category_m1 = self.postprocess_ssh(
                    rois=anchors_m1,
                    bbox_ppred=box_pred_m1,
                    scores=cls_prob_m1,
                    img_shape=img_shape,
                    iou_threshold=cfgs.M1_NMS_IOU_THRESHOLD)

            with tf.variable_scope('postprocess_ssh_m2'):
                final_bbox_m2, final_scores_m2, final_category_m2 = self.postprocess_ssh(
                    rois=anchors_m2,
                    bbox_ppred=box_pred_m2,
                    scores=cls_prob_m2,
                    img_shape=img_shape,
                    iou_threshold=cfgs.M2_NMS_IOU_THRESHOLD)

            with tf.variable_scope('postprocess_ssh_m3'):
                final_bbox_m3, final_scores_m3, final_category_m3 = self.postprocess_ssh(
                    rois=anchors_m3,
                    bbox_ppred=box_pred_m3,
                    scores=cls_prob_m3,
                    img_shape=img_shape,
                    iou_threshold=cfgs.M3_NMS_IOU_THRESHOLD)

            result_dict = {
                'final_bbox_m1': final_bbox_m1,
                'final_scores_m1': final_scores_m1,
                'final_category_m1': final_category_m1,
                'final_bbox_m2': final_bbox_m2,
                'final_scores_m2': final_scores_m2,
                'final_category_m2': final_category_m2,
                'final_bbox_m3': final_bbox_m3,
                'final_scores_m3': final_scores_m3,
                'final_category_m3': final_category_m3
            }
            return result_dict

        # 5b. training path: per-module losses + postprocessed detections
        else:
            # NOTE(review): if any M*_MINIBATCH_SIZE == -1 the corresponding
            # bbox_loss_m*/cls_loss_m* are never assigned and losses_dict
            # below raises NameError — confirm configs always set them.
            with tf.variable_scope('ssh_loss_m1'):

                if not cfgs.M1_MINIBATCH_SIZE == -1:

                    # Restrict predictions to the sampled minibatch rows.
                    box_pred_m1 = tf.gather(box_pred_m1, keep_inds_m1)
                    cls_score_m1 = tf.gather(cls_score_m1, keep_inds_m1)
                    cls_prob_m1 = tf.reshape(
                        tf.gather(cls_prob_m1, keep_inds_m1),
                        [-1, (cfgs.CLASS_NUM + 1)])

                    bbox_loss_m1 = losses.smooth_l1_loss_rcnn(
                        bbox_pred=box_pred_m1,
                        bbox_targets=bbox_targets_m1,
                        label=labels_m1,
                        num_classes=cfgs.CLASS_NUM + 1,
                        sigma=cfgs.M1_SIGMA)

                    cls_loss_m1 = tf.reduce_mean(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            logits=cls_score_m1, labels=labels_m1))

            with tf.variable_scope('postprocess_ssh_m1'):
                # BUG FIX: this used cfgs.M2_NMS_IOU_THRESHOLD for the M1
                # module (copy-paste error); the inference path above uses
                # M1_NMS_IOU_THRESHOLD for the same postprocess.
                final_bbox_m1, final_scores_m1, final_category_m1 = self.postprocess_ssh(
                    rois=rois_m1,
                    bbox_ppred=box_pred_m1,
                    scores=cls_prob_m1,
                    img_shape=img_shape,
                    iou_threshold=cfgs.M1_NMS_IOU_THRESHOLD)

            with tf.variable_scope('ssh_loss_m2'):
                if not cfgs.M2_MINIBATCH_SIZE == -1:

                    box_pred_m2 = tf.gather(box_pred_m2, keep_inds_m2)
                    cls_score_m2 = tf.gather(cls_score_m2, keep_inds_m2)
                    cls_prob_m2 = tf.reshape(
                        tf.gather(cls_prob_m2, keep_inds_m2),
                        [-1, (cfgs.CLASS_NUM + 1)])

                    bbox_loss_m2 = losses.smooth_l1_loss_rcnn(
                        bbox_pred=box_pred_m2,
                        bbox_targets=bbox_targets_m2,
                        label=labels_m2,
                        num_classes=cfgs.CLASS_NUM + 1,
                        sigma=cfgs.M2_SIGMA)

                    cls_loss_m2 = tf.reduce_mean(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            logits=cls_score_m2, labels=labels_m2))

            with tf.variable_scope('postprocess_ssh_m2'):
                final_bbox_m2, final_scores_m2, final_category_m2 = self.postprocess_ssh(
                    rois=rois_m2,
                    bbox_ppred=box_pred_m2,
                    scores=cls_prob_m2,
                    img_shape=img_shape,
                    iou_threshold=cfgs.M2_NMS_IOU_THRESHOLD)

            with tf.variable_scope('ssh_loss_m3'):
                if not cfgs.M3_MINIBATCH_SIZE == -1:

                    box_pred_m3 = tf.gather(box_pred_m3, keep_inds_m3)
                    cls_score_m3 = tf.gather(cls_score_m3, keep_inds_m3)
                    cls_prob_m3 = tf.reshape(
                        tf.gather(cls_prob_m3, keep_inds_m3),
                        [-1, (cfgs.CLASS_NUM + 1)])

                    bbox_loss_m3 = losses.smooth_l1_loss_rcnn(
                        bbox_pred=box_pred_m3,
                        bbox_targets=bbox_targets_m3,
                        label=labels_m3,
                        num_classes=cfgs.CLASS_NUM + 1,
                        sigma=cfgs.M3_SIGMA)

                    cls_loss_m3 = tf.reduce_mean(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            logits=cls_score_m3, labels=labels_m3))

            with tf.variable_scope('postprocess_ssh_m3'):
                final_bbox_m3, final_scores_m3, final_category_m3 = self.postprocess_ssh(
                    rois=rois_m3,
                    bbox_ppred=box_pred_m3,
                    scores=cls_prob_m3,
                    img_shape=img_shape,
                    iou_threshold=cfgs.M3_NMS_IOU_THRESHOLD)

            result_dict = {
                'final_bbox_m1': final_bbox_m1,
                'final_scores_m1': final_scores_m1,
                'final_category_m1': final_category_m1,
                'final_bbox_m2': final_bbox_m2,
                'final_scores_m2': final_scores_m2,
                'final_category_m2': final_category_m2,
                'final_bbox_m3': final_bbox_m3,
                'final_scores_m3': final_scores_m3,
                'final_category_m3': final_category_m3
            }

            losses_dict = {
                'bbox_loss_m1': bbox_loss_m1,
                'cls_loss_m1': cls_loss_m1,
                'bbox_loss_m2': bbox_loss_m2,
                'cls_loss_m2': cls_loss_m2,
                'bbox_loss_m3': bbox_loss_m3,
                'cls_loss_m3': cls_loss_m3
            }

            return result_dict, losses_dict