Beispiel #1
0
    def build_whole_detection_network(self, rgb_input_img_batch, ir_input_img_batch, gtboxes_batch):
        """Build the FCOS-style detection graph from paired RGB and IR batches.

        The backbone (``build_base_network``) consumes both image batches.
        ``rpn_net`` then predicts class probabilities, centerness logits and
        boxes. Class probabilities are rescaled by sigmoid(centerness) for
        decoding. Detections are decoded for the first image of the batch in
        both modes.

        Args:
            rgb_input_img_batch: RGB image batch; its dynamic shape is passed
                to post-processing as ``img_shape``.
            ir_input_img_batch: IR image batch, consumed by the backbone.
            gtboxes_batch: ground-truth boxes, used only when training. They
                are reshaped to ``[self.batch_size, -1, 5]`` and cast to
                float32.

        Returns:
            When not training: ``(boxes, scores, category)``.
            When training: ``(boxes, scores, category, loss_dict)``, where
            ``loss_dict`` holds ``rpn_cls_loss``, ``rpn_bbox_loss`` and
            ``rpn_ctr_loss``.
        """
        if self.is_training:
            # Regroup the boxes per image as [batch_size, num_boxes, 5] float32.
            gtboxes_batch = tf.cast(tf.reshape(gtboxes_batch, [self.batch_size, -1, 5]), tf.float32)

        img_shape = tf.shape(rgb_input_img_batch)
        pyramid = self.build_base_network(rgb_input_img_batch, ir_input_img_batch)  # [P3, P4, P5, P6, P7]

        # The raw class scores (first output) are not needed here.
        _, cls_prob, ctr_logits, box_pred = self.rpn_net(pyramid)

        # sigmoid(centerness) is expanded on axis 2 and broadcast to
        # [batch_size, dim1, dim2] of cls_prob, so it rescales every entry along
        # cls_prob's last axis (presumably the classes -- confirm with rpn_net).
        ctr_prob = broadcast_to(
            tf.expand_dims(tf.nn.sigmoid(ctr_logits), axis=2),
            [self.batch_size, tf.shape(cls_prob)[1], tf.shape(cls_prob)[2]])
        fused_prob = cls_prob * ctr_prob

        # Decode detections for the first image of the batch; identical in
        # training and inference.
        with tf.variable_scope('postprocess_detctions'):
            boxes, scores, category = postprocess_detctions(rpn_bbox=box_pred[0, :, :],
                                                            rpn_cls_prob=fused_prob[0, :, :],
                                                            img_shape=img_shape)
        if not self.is_training:
            return boxes, scores, category

        with tf.variable_scope('build_loss'):
            targets = self._fcos_target(pyramid, rgb_input_img_batch, gtboxes_batch)

            # Target layout along the last axis: [cls, centerness, box...].
            cls_gt = tf.stop_gradient(targets[:, :, 0])
            ctr_gt = tf.stop_gradient(targets[:, :, 1])
            gt_boxes = tf.stop_gradient(targets[:, :, 2:])

            # The focal loss uses the un-fused class probabilities; centerness
            # is supervised separately.
            loss_dict = {
                'rpn_cls_loss': losses_fcos.focal_loss(cls_prob, cls_gt, alpha=cfgs.ALPHA, gamma=cfgs.GAMMA),
                'rpn_bbox_loss': losses_fcos.iou_loss(box_pred, gt_boxes, cls_gt, weight=ctr_gt),
                'rpn_ctr_loss': losses_fcos.centerness_loss(ctr_logits, ctr_gt, cls_gt),
            }

        return boxes, scores, category, loss_dict
Beispiel #2
0
    def build_whole_detection_network(self, rgb_input_img_batch,
                                      ir_input_img_batch, seg_mask_batch,
                                      gtboxes_batch):
        """Build the three-branch (multi / rgb / ir) FCOS-style detection graph.

        ``build_base_network`` returns one feature pyramid per branch, and
        ``rpn_net`` is run on each of them. The per-branch predictions are
        concatenated along axis 1 in the order multi, rgb, ir. Detections are
        then decoded from the concatenation and, in training, the losses are
        computed on it.

        Args:
            rgb_input_img_batch: RGB image batch; its dynamic shape is passed
                to post-processing as ``img_shape``.
            ir_input_img_batch: IR image batch, consumed by the backbone.
            seg_mask_batch: segmentation masks, used only when training to
                build per-level targets via ``_seg_target``.
            gtboxes_batch: ground-truth boxes, used only when training. They
                are reshaped to ``[self.batch_size, -1, 5]`` and cast to
                float32.

        Returns:
            When not training: ``(boxes, scores, category)`` for the first
            image of the batch.
            When training: ``(boxes, scores, category, loss_dict)``.
            ``loss_dict`` holds ``rpn_cls_loss``, ``rpn_bbox_loss`` and
            ``rpn_ctr_loss``, plus ``multi_seg_loss``, ``rgb_seg_loss`` and
            ``ir_seg_loss``. Each segmentation loss is the mean over
            ``cfgs.LEVLES`` of per-level losses.
        """
        if self.is_training:
            # Regroup the boxes per image: [batch_size, num_boxes, 5] float32.
            gtboxes_batch = tf.reshape(gtboxes_batch, [self.batch_size, -1, 5])
            gtboxes_batch = tf.cast(gtboxes_batch, tf.float32)

        img_shape = tf.shape(rgb_input_img_batch)

        feature_pyramid_multi, feature_pyramid_rgb, feature_pyramid_ir = self.build_base_network(
            rgb_input_img_batch, ir_input_img_batch)  # [P3, P4, P5, P6, P7]

        # Branch order matters: predictions are concatenated in this order,
        # and the FCOS targets below are tiled once per branch to match it.
        branches = (('multi', feature_pyramid_multi),
                    ('rgb', feature_pyramid_rgb),
                    ('ir', feature_pyramid_ir))

        box_list, prob_list, cls_prob_list, cnt_scores_list = [], [], [], []
        seg_logits = {}
        for branch, feature_pyramid in branches:
            # The raw class scores (first output) are not used.
            _, cls_prob, cnt_scores, box, seg = self.rpn_net(feature_pyramid, branch)

            # Rescale class probabilities by sigmoid(centerness). The
            # centerness is expanded on axis 2 and broadcast to
            # [batch_size, dim1, dim2] of cls_prob.
            cnt_prob = tf.nn.sigmoid(cnt_scores)
            cnt_prob = tf.expand_dims(cnt_prob, axis=2)
            cnt_prob = broadcast_to(cnt_prob, [
                self.batch_size,
                tf.shape(cls_prob)[1],
                tf.shape(cls_prob)[2]
            ])

            box_list.append(box)
            prob_list.append(cls_prob * cnt_prob)
            cls_prob_list.append(cls_prob)
            cnt_scores_list.append(cnt_scores)
            # NOTE(review): presumably a per-level sequence of logits indexed
            # like cfgs.LEVLES -- confirm against rpn_net.
            seg_logits[branch] = seg

        rpn_box = tf.concat(box_list, axis=1)
        rpn_prob = tf.concat(prob_list, axis=1)
        rpn_cls_prob = tf.concat(cls_prob_list, axis=1)
        rpn_cnt_scores = tf.concat(cnt_scores_list, axis=1)

        # Decode detections for the first image of the batch; identical in
        # training and inference.
        with tf.variable_scope('postprocess_detctions'):
            boxes, scores, category = postprocess_detctions(
                rpn_bbox=rpn_box[0, :, :],
                rpn_cls_prob=rpn_prob[0, :, :],
                img_shape=img_shape)

        if not self.is_training:
            return boxes, scores, category

        with tf.variable_scope('build_loss'):
            # The FCOS targets are computed once from the multi pyramid.
            fcos_target_single = self._fcos_target(feature_pyramid_multi,
                                                   rgb_input_img_batch,
                                                   gtboxes_batch)

            # Segmentation loss per level and branch. Each level's target is
            # built once from the multi pyramid and shared by all branches.
            seg_losses = {branch: [] for branch, _ in branches}
            for i, level in enumerate(cfgs.LEVLES):
                seg_target_batch = self._seg_target(
                    feature_pyramid_multi, seg_mask_batch, level)
                for branch, _ in branches:
                    seg_losses[branch].append(
                        tf.reduce_mean(
                            tf.nn.sparse_softmax_cross_entropy_with_logits(
                                logits=seg_logits[branch][i],
                                labels=seg_target_batch)))

            # Tile the targets so they line up with the axis-1 concatenation
            # of the per-branch predictions.
            fcos_target_batch = tf.concat(
                [fcos_target_single] * len(branches), axis=1)

            # Target layout along the last axis: [cls, centerness, box...].
            cls_gt = tf.stop_gradient(fcos_target_batch[:, :, 0])
            ctr_gt = tf.stop_gradient(fcos_target_batch[:, :, 1])
            gt_boxes = tf.stop_gradient(fcos_target_batch[:, :, 2:])

            # The focal loss uses the un-fused class probabilities; centerness
            # is supervised separately.
            rpn_cls_loss = losses_fcos.focal_loss(rpn_cls_prob,
                                                  cls_gt,
                                                  alpha=cfgs.ALPHA,
                                                  gamma=cfgs.GAMMA)
            rpn_bbox_loss = losses_fcos.iou_loss(rpn_box,
                                                 gt_boxes,
                                                 cls_gt,
                                                 weight=ctr_gt)
            rpn_ctr_loss = losses_fcos.centerness_loss(
                rpn_cnt_scores, ctr_gt, cls_gt)

            # Each *_seg_loss is the mean of that branch's per-level losses.
            loss_dict = {
                'rpn_cls_loss': rpn_cls_loss,
                'rpn_bbox_loss': rpn_bbox_loss,
                'rpn_ctr_loss': rpn_ctr_loss,
                'multi_seg_loss':
                sum(seg_losses['multi']) / len(seg_losses['multi']),
                'rgb_seg_loss':
                sum(seg_losses['rgb']) / len(seg_losses['rgb']),
                'ir_seg_loss':
                sum(seg_losses['ir']) / len(seg_losses['ir']),
            }

        return boxes, scores, category, loss_dict