Exemplo n.º 1
0
    def _build_graph(self, inputs):
        """Build the Faster R-CNN graph (RPN + Fast R-CNN head) for one tower.

        In training mode this sets ``self.cost`` to the summed RPN, Fast R-CNN
        and weight-decay losses. In inference mode it instead creates the
        named tensors ``fastrcnn_all_probs``, ``fastrcnn_fg_probs`` and
        ``fastrcnn_fg_boxes`` and does not set ``self.cost``. Nothing is
        returned in either mode.

        Args:
            inputs: a 5-sequence unpacked as
                ``(image, anchor_labels, anchor_boxes, gt_boxes, gt_labels)``.
                NOTE(review): ``image`` is presumably a single un-batched
                HxWxC image, since a batch axis of size 1 is added below —
                confirm against the model's input spec.
        """
        is_training = get_current_tower_context().is_training
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
        # Add a leading batch axis of size 1.
        image = tf.expand_dims(image, 0)

        # FSxFSxNAx4 (FS=MAX_SIZE//ANCHOR_STRIDE)
        # Crop the precomputed anchor grid to this image's feature-map size:
        # dims 1 and 2 of the batched image (taken *before* preprocessing and
        # the transpose below) divided by ANCHOR_STRIDE; -1 keeps the last
        # two axes whole.
        with tf.name_scope('anchors'):
            all_anchors = tf.constant(get_all_anchors(),
                                      name='all_anchors',
                                      dtype=tf.float32)
            fm_anchors = tf.slice(
                all_anchors, [0, 0, 0, 0],
                tf.stack([
                    tf.shape(image)[1] // config.ANCHOR_STRIDE,
                    tf.shape(image)[2] // config.ANCHOR_STRIDE, -1, -1
                ]),
                name='fm_anchors')
            # Encode the ground-truth anchor boxes relative to the cropped
            # anchors; these are the box targets passed to rpn_losses below.
            anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)

        image = image_preprocess(image, bgr=True)
        # Permute axes [0, 3, 1, 2]; presumably NHWC -> NCHW for the backbone.
        image = tf.transpose(image, [0, 3, 1, 2])

        # resnet50: conv stages up to conv4 with [3, 4, 6] blocks.
        featuremap = pretrained_resnet_conv4(image, [3, 4, 6])
        rpn_label_logits, rpn_box_logits = rpn_head(featuremap)
        rpn_label_loss, rpn_box_loss = rpn_losses(anchor_labels,
                                                  anchor_boxes_encoded,
                                                  rpn_label_logits,
                                                  rpn_box_logits)

        # Decode RPN box predictions against the anchors, then turn them into
        # proposals using the flattened RPN label logits and the image's
        # spatial shape (dims 2: of the transposed image, presumably H, W).
        # NOTE: proposal_scores is not used anywhere in this method.
        decoded_boxes = decode_bbox_target(
            rpn_box_logits, fm_anchors)  # (fHxfWxNA)x4, floatbox
        proposal_boxes, proposal_scores = generate_rpn_proposals(
            decoded_boxes, tf.reshape(rpn_label_logits, [-1]),
            tf.shape(image)[2:])

        if is_training:
            # Sample proposals and their classification/regression targets
            # from the ground truth.
            rcnn_sampled_boxes, rcnn_encoded_boxes, rcnn_labels = sample_fast_rcnn_targets(
                proposal_boxes, gt_boxes, gt_labels)
            # Rescale boxes from image coordinates to feature-map coordinates.
            boxes_on_featuremap = rcnn_sampled_boxes * (1.0 /
                                                        config.ANCHOR_STRIDE)
            # ROI-align each sampled box into a 14x14 crop of the feature map.
            roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
            feature_fastrcnn = resnet_conv5(roi_resized)  # nxc
            fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(
                feature_fastrcnn, config.NUM_CLASS)

            fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
                rcnn_labels, rcnn_encoded_boxes, fastrcnn_label_logits,
                fastrcnn_box_logits)

            # L2 weight decay (scale 1e-4) on variables whose names match the
            # regex; scopes not listed (e.g. group0) are excluded.
            wd_cost = regularize_cost(
                '(?:group1|group2|group3|rpn|fastrcnn)/.*W',
                l2_regularizer(1e-4),
                name='wd_cost')

            self.cost = tf.add_n([
                rpn_label_loss, rpn_box_loss, fastrcnn_label_loss,
                fastrcnn_box_loss, wd_cost
            ], 'total_cost')

            # Iterates the tuple (self.cost, wd_cost).
            for k in self.cost, wd_cost:
                add_moving_summary(k)
        else:
            # Inference: run the same ROI pipeline on all proposals, rescaled
            # to feature-map coordinates.
            roi_resized = roi_align(
                featuremap, proposal_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
            feature_fastrcnn = resnet_conv5(roi_resized)  # nxc
            label_logits, fastrcnn_box_logits = fastrcnn_head(
                feature_fastrcnn, config.NUM_CLASS)
            label_probs = tf.nn.softmax(label_logits,
                                        name='fastrcnn_all_probs')  # NP,
            labels = tf.argmax(label_logits, axis=1)
            # NOTE(review): fastrcnn_predict_boxes presumably selects proposals
            # predicted as foreground and their box logits for the predicted
            # class — the helper is not visible here; confirm.
            fg_ind, fg_box_logits = fastrcnn_predict_boxes(
                labels, fastrcnn_box_logits)
            # The local is unused below; the op is presumably fetched by its
            # name 'fastrcnn_fg_probs' from outside — confirm before removing.
            fg_label_probs = tf.gather(label_probs,
                                       fg_ind,
                                       name='fastrcnn_fg_probs')
            fg_boxes = tf.gather(proposal_boxes, fg_ind)

            # NOTE(review): dividing by FASTRCNN_BBOX_REG_WEIGHTS presumably
            # undoes a per-coordinate scaling applied to the regression
            # targets during training — confirm in sample_fast_rcnn_targets.
            fg_box_logits = fg_box_logits / tf.constant(
                config.FASTRCNN_BBOX_REG_WEIGHTS)
            decoded_boxes = decode_bbox_target(fg_box_logits,
                                               fg_boxes)  # Nfx4, floatbox
            # Expose the final boxes under a fixed tensor name.
            decoded_boxes = tf.identity(decoded_boxes,
                                        name='fastrcnn_fg_boxes')
Exemplo n.º 2
0
    def _build_graph(self, inputs):
        """Build the Faster R-CNN graph (RPN + Fast R-CNN head) for one tower.

        In training mode this sets ``self.cost`` to the summed RPN, Fast R-CNN
        and weight-decay losses. In inference mode it instead creates the
        named tensors ``fastrcnn_all_probs``, ``fastrcnn_fg_probs`` and
        ``fastrcnn_fg_boxes`` and does not set ``self.cost``. Nothing is
        returned in either mode.

        Args:
            inputs: a 5-sequence unpacked as
                ``(image, anchor_labels, anchor_boxes, gt_boxes, gt_labels)``.
        """
        is_training = get_current_tower_context().is_training
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
        # Anchors are derived from the raw image, before self._preprocess.
        fm_anchors = self._get_anchors(image)
        image = self._preprocess(image)

        # Encode the ground-truth anchor boxes relative to the anchors; these
        # are the box targets passed to rpn_losses below.
        anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
        # Backbone up to conv4, using the first three entries of
        # RESNET_NUM_BLOCK; the last entry is used by the conv5 head below.
        featuremap = pretrained_resnet_conv4(image,
                                             config.RESNET_NUM_BLOCK[:3])
        rpn_label_logits, rpn_box_logits = rpn_head(featuremap, 1024,
                                                    config.NUM_ANCHOR)
        rpn_label_loss, rpn_box_loss = rpn_losses(anchor_labels,
                                                  anchor_boxes_encoded,
                                                  rpn_label_logits,
                                                  rpn_box_logits)

        # Decode RPN box predictions against the anchors, then turn them into
        # proposals using the flattened RPN label logits and tf.shape(image)[2:]
        # (presumably the spatial H, W of the preprocessed image — confirm
        # self._preprocess produces a channels-first layout).
        # NOTE(review): decode_bbox_target is called with ANCHOR_STRIDE as a
        # third argument; its signature is not visible here — confirm.
        # NOTE: proposal_scores is not used anywhere in this method.
        decoded_boxes = decode_bbox_target(
            rpn_box_logits, fm_anchors,
            config.ANCHOR_STRIDE)  # (fHxfWxNA)x4, floatbox
        proposal_boxes, proposal_scores = generate_rpn_proposals(
            decoded_boxes, tf.reshape(rpn_label_logits, [-1]),
            tf.shape(image)[2:])

        if is_training:
            # Sample proposals and their classification/regression targets
            # from the ground truth.
            rcnn_sampled_boxes, rcnn_encoded_boxes, rcnn_labels = sample_fast_rcnn_targets(
                proposal_boxes, gt_boxes, gt_labels)
            # Rescale boxes from image coordinates to feature-map coordinates.
            boxes_on_featuremap = rcnn_sampled_boxes * (1.0 /
                                                        config.ANCHOR_STRIDE)
            # ROI-align each sampled box into a 14x14 crop of the feature map.
            roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
            feature_fastrcnn = resnet_conv5_gap(
                roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxc
            fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(
                feature_fastrcnn, config.NUM_CLASS)

            fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
                rcnn_labels, rcnn_encoded_boxes, fastrcnn_label_logits,
                fastrcnn_box_logits)

            # L2 weight decay (scale 1e-4) on variables whose names match the
            # regex; scopes not listed (e.g. group0) are excluded.
            wd_cost = regularize_cost(
                '(?:group1|group2|group3|rpn|fastrcnn)/.*W',
                l2_regularizer(1e-4),
                name='wd_cost')

            self.cost = tf.add_n([
                rpn_label_loss, rpn_box_loss, fastrcnn_label_loss,
                fastrcnn_box_loss, wd_cost
            ], 'total_cost')

            # Iterates the tuple (self.cost, wd_cost).
            for k in self.cost, wd_cost:
                add_moving_summary(k)
        else:
            # Inference: run the same ROI pipeline on all proposals, rescaled
            # to feature-map coordinates.
            roi_resized = roi_align(
                featuremap, proposal_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
            feature_fastrcnn = resnet_conv5_gap(
                roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxc
            label_logits, fastrcnn_box_logits = fastrcnn_head(
                feature_fastrcnn, config.NUM_CLASS)
            label_probs = tf.nn.softmax(label_logits,
                                        name='fastrcnn_all_probs')  # NP,
            labels = tf.argmax(label_logits, axis=1)
            # NOTE(review): fastrcnn_predict_boxes presumably selects proposals
            # predicted as foreground and their box logits for the predicted
            # class — the helper is not visible here; confirm.
            fg_ind, fg_box_logits = fastrcnn_predict_boxes(
                labels, fastrcnn_box_logits)
            # The local is unused below; the op is presumably fetched by its
            # name 'fastrcnn_fg_probs' from outside — confirm before removing.
            fg_label_probs = tf.gather(label_probs,
                                       fg_ind,
                                       name='fastrcnn_fg_probs')
            fg_boxes = tf.gather(proposal_boxes, fg_ind)

            # NOTE(review): dividing by FASTRCNN_BBOX_REG_WEIGHTS presumably
            # undoes a per-coordinate scaling applied to the regression
            # targets during training — confirm in sample_fast_rcnn_targets.
            fg_box_logits = fg_box_logits / tf.constant(
                config.FASTRCNN_BBOX_REG_WEIGHTS)
            decoded_boxes = decode_bbox_target(
                fg_box_logits, fg_boxes,
                config.ANCHOR_STRIDE)  # Nfx4, floatbox
            # Expose the final boxes under a fixed tensor name.
            decoded_boxes = tf.identity(decoded_boxes,
                                        name='fastrcnn_fg_boxes')