Exemple #1
0
    def fastrcnn_inference(self, image_shape2d, rcnn_boxes, rcnn_label_logits,
                           rcnn_box_logits):
        """
        Turn the Fast R-CNN head outputs into the final detections.

        Besides the returned tensors, this creates the named tensors
        ``fastrcnn_all_probs``, ``fastrcnn_all_boxes``, ``final_probs``,
        ``final_boxes`` and ``final_labels`` in the graph.

        Args:
            image_shape2d: h, w
            rcnn_boxes (nx4): the proposal boxes
            rcnn_label_logits: class logits of the proposals, fed to softmax
                (presumably n x #class -- confirm against the head)
            rcnn_box_logits: box regression outputs of the proposals
                (presumably n x #category x 4 to match the tiled proposals
                -- confirm against the head)

        Returns:
            boxes (mx4):
            labels (m): each >= 1
        """
        num_category = config.NUM_CLASS - 1

        # #proposal x #Class
        class_probs = tf.nn.softmax(rcnn_label_logits,
                                    name='fastrcnn_all_probs')

        # Repeat every proposal once per category: #proposal x #Cat x 4
        boxes_per_category = tf.tile(tf.expand_dims(rcnn_boxes, 1),
                                     [1, num_category, 1])

        # The box logits are divided by the configured regression weights
        # before being decoded against the tiled proposals.
        reg_weights = tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
        all_boxes = clip_boxes(
            decode_bbox_target(rcnn_box_logits / reg_weights,
                               boxes_per_category),
            image_shape2d,
            name='fastrcnn_all_boxes')

        # selected: Nx2. Each index into (#proposal, #category)
        selected, selected_probs = fastrcnn_predictions(all_boxes, class_probs)

        # Created only so the scores are reachable under a fixed tensor name.
        tf.identity(selected_probs, 'final_probs')

        boxes = tf.gather_nd(all_boxes, selected, name='final_boxes')
        # Column 1 is the 0-based category index; labels start at 1.
        labels = tf.add(selected[:, 1], 1, name='final_labels')
        return boxes, labels
Exemple #2
0
    def fastrcnn_inference(self, image_shape2d,
                           rcnn_boxes, rcnn_label_logits, rcnn_box_logits):
        """
        Post-process Fast R-CNN head outputs into final boxes and labels.

        The tensors ``fastrcnn_all_probs``, ``fastrcnn_all_boxes``,
        ``final_probs``, ``final_boxes`` and ``final_labels`` are created
        by name as a side effect.

        Args:
            image_shape2d: h, w
            rcnn_boxes (nx4): the proposal boxes
            rcnn_label_logits: per-proposal class logits
                (presumably n x #class -- confirm against the head)
            rcnn_box_logits: per-proposal box regression outputs
                (presumably n x #category x 4 -- confirm against the head)

        Returns:
            boxes (mx4):
            labels (m): each >= 1
        """
        # #proposal x #Class
        probs_all = tf.nn.softmax(rcnn_label_logits, name='fastrcnn_all_probs')

        # #proposal x #Cat x 4
        proposals_3d = tf.expand_dims(rcnn_boxes, 1)
        tiled_proposals = tf.tile(proposals_3d, [1, config.NUM_CLASS - 1, 1])

        # Scale the regression outputs by 1 / FASTRCNN_BBOX_REG_WEIGHTS,
        # decode them against the tiled proposals, then clip to the image.
        scaled_logits = rcnn_box_logits / tf.constant(
            config.FASTRCNN_BBOX_REG_WEIGHTS)
        boxes_all = decode_bbox_target(scaled_logits, tiled_proposals)
        boxes_all = clip_boxes(
            boxes_all, image_shape2d, name='fastrcnn_all_boxes')

        # keep: Nx2. Each index into (#proposal, #category)
        keep, keep_probs = fastrcnn_predictions(boxes_all, probs_all)
        # The identity op exists only to name the score tensor.
        tf.identity(keep_probs, 'final_probs')

        out_boxes = tf.gather_nd(boxes_all, keep, name='final_boxes')
        out_labels = tf.add(keep[:, 1], 1, name='final_labels')
        return out_boxes, out_labels
Exemple #3
0
    def _build_graph(self, inputs):
        """
        Build the detection graph for the current tower.

        In training mode (as reported by the current tower context) the RPN,
        Fast R-CNN, optional mask and weight-decay costs are summed into
        ``self.cost``. In inference mode the named tensors
        ``fastrcnn_all_probs``, ``fastrcnn_all_boxes``, ``final_probs``,
        ``final_boxes``, ``final_labels`` and, when ``config.MODE_MASK`` is
        set, ``final_masks`` are created instead. Nothing is returned.

        Args:
            inputs: image, anchor_labels, anchor_boxes, gt_boxes, gt_labels,
                followed by gt_masks when ``config.MODE_MASK`` is set.
        """
        is_training = get_current_tower_context().is_training
        if config.MODE_MASK:
            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
        else:
            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
        # Anchors are derived from the raw image, before preprocessing.
        fm_anchors = self._get_anchors(image)
        image = self._preprocess(image)  # 1CHW
        image_shape2d = tf.shape(image)[2:]

        # ---- backbone + RPN ----
        anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
        featuremap = pretrained_resnet_conv4(image,
                                             config.RESNET_NUM_BLOCK[:3])
        rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024,
                                                    config.NUM_ANCHOR)

        decoded_boxes = decode_bbox_target(rpn_box_logits,
                                           fm_anchors)  # fHxfWxNAx4, floatbox
        # proposal_scores is not used further in this function.
        proposal_boxes, proposal_scores = generate_rpn_proposals(
            tf.reshape(decoded_boxes, [-1, 4]),
            tf.reshape(rpn_label_logits, [-1]), image_shape2d)

        # Boxes are multiplied by 1/ANCHOR_STRIDE before RoI pooling on the
        # conv4 feature map.
        if is_training:
            # sample proposal boxes in training
            rcnn_sampled_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
                proposal_boxes, gt_boxes, gt_labels)
            boxes_on_featuremap = rcnn_sampled_boxes * (1.0 /
                                                        config.ANCHOR_STRIDE)
        else:
            # use all proposal boxes in inference
            boxes_on_featuremap = proposal_boxes * (1.0 / config.ANCHOR_STRIDE)

        # roi_align runs unconditionally; only conv5 + the head are guarded
        # by the tf.cond below.
        roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)

        # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
        def ff_true():
            # Regular path: conv5 on the pooled RoIs, then the Fast R-CNN head.
            feature_fastrcnn = resnet_conv5(
                roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxcx7x7
            fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(
                'fastrcnn', feature_fastrcnn, config.NUM_CLASS)
            return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits

        def ff_false():
            # Empty (zero-row) stand-ins used when there are no boxes; the
            # shapes are presumably chosen to match ff_true's outputs
            # (TODO confirm 2048 channels against resnet_conv5).
            ncls = config.NUM_CLASS
            return tf.zeros([0, 2048, 7,
                             7]), tf.zeros([0,
                                            ncls]), tf.zeros([0, ncls - 1, 4])

        feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
            tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)

        if is_training:
            # rpn loss
            rpn_label_loss, rpn_box_loss = rpn_losses(anchor_labels,
                                                      anchor_boxes_encoded,
                                                      rpn_label_logits,
                                                      rpn_box_logits)

            # fastrcnn loss
            fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0),
                                            [-1])  # fg inds w.r.t all samples
            fg_sampled_boxes = tf.gather(rcnn_sampled_boxes,
                                         fg_inds_wrt_sample)

            # Image summary of the foreground sample crops; not part of any
            # loss.
            with tf.name_scope('fg_sample_patch_viz'):
                # The all-zero index tensor presumably maps every box to
                # image 0 of the batch -- confirm against crop_and_resize.
                fg_sampled_patches = crop_and_resize(
                    image, fg_sampled_boxes,
                    tf.zeros_like(fg_inds_wrt_sample, dtype=tf.int32), 300)
                fg_sampled_patches = tf.transpose(fg_sampled_patches,
                                                  [0, 2, 3, 1])
                fg_sampled_patches = tf.reverse(fg_sampled_patches,
                                                axis=[-1])  # BGR->RGB
                tf.summary.image('viz', fg_sampled_patches, max_outputs=30)

            # Regression targets: matched GT boxes encoded against the
            # foreground samples, multiplied by FASTRCNN_BBOX_REG_WEIGHTS
            # (inference divides the logits by the same constant).
            matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
            encoded_boxes = encode_bbox_target(
                matched_gt_boxes, fg_sampled_boxes) * tf.constant(
                    config.FASTRCNN_BBOX_REG_WEIGHTS)
            fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
                rcnn_labels, fastrcnn_label_logits, encoded_boxes,
                tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample))

            if config.MODE_MASK:
                # maskrcnn loss
                # The mask head reuses the conv5 features of the foreground
                # samples.
                fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
                fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
                mask_logits = maskrcnn_head(
                    'maskrcnn', fg_feature,
                    config.NUM_CLASS)  # #fg x #cat x 14x14

                gt_masks_for_fg = tf.gather(gt_masks,
                                            fg_inds_wrt_gt)  # nfg x H x W
                # One gathered GT mask per foreground box, hence the
                # tf.range index tensor.
                target_masks_for_fg = crop_and_resize(
                    tf.expand_dims(gt_masks_for_fg, 1), fg_sampled_boxes,
                    tf.range(tf.size(fg_inds_wrt_gt)), 14)  # nfg x 1x14x14
                target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1,
                                                 'sampled_fg_mask_targets')
                mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels,
                                           target_masks_for_fg)
            else:
                # Placeholder so the tf.add_n list below has the same form
                # in both modes.
                mrcnn_loss = 0.0

            # Weight decay over variables whose names match the regex.
            wd_cost = regularize_cost(
                '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
                l2_regularizer(1e-4),
                name='wd_cost')

            self.cost = tf.add_n([
                rpn_label_loss, rpn_box_loss, fastrcnn_label_loss,
                fastrcnn_box_loss, mrcnn_loss, wd_cost
            ], 'total_cost')

            add_moving_summary(self.cost, wd_cost)
        else:
            # ---- inference: decode, clip and select final detections ----
            label_probs = tf.nn.softmax(
                fastrcnn_label_logits,
                name='fastrcnn_all_probs')  # #proposal x #Class
            anchors = tf.tile(
                tf.expand_dims(proposal_boxes, 1),
                [1, config.NUM_CLASS - 1, 1])  # #proposal x #Cat x 4
            decoded_boxes = decode_bbox_target(
                fastrcnn_box_logits /
                tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS), anchors)
            decoded_boxes = clip_boxes(decoded_boxes,
                                       image_shape2d,
                                       name='fastrcnn_all_boxes')

            # indices: Nx2. Each index into (#proposal, #category)
            pred_indices, final_probs = fastrcnn_predictions(
                decoded_boxes, label_probs)
            # tf.identity is used here to give the tensor a fixed name.
            final_probs = tf.identity(final_probs, 'final_probs')
            final_boxes = tf.gather_nd(decoded_boxes,
                                       pred_indices,
                                       name='final_boxes')
            # Category indices are 0-based; labels are shifted up by one.
            final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')

            if config.MODE_MASK:
                # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
                def f1():
                    # NOTE: this roi_resized is local to f1 and shadows the
                    # outer variable; masks are pooled from the final boxes.
                    roi_resized = roi_align(
                        featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE),
                        14)
                    feature_maskrcnn = resnet_conv5(
                        roi_resized, config.RESNET_NUM_BLOCK[-1])
                    mask_logits = maskrcnn_head(
                        'maskrcnn', feature_maskrcnn,
                        config.NUM_CLASS)  # #result x #cat x 14x14
                    # Pair each result row with its (label - 1) category so
                    # gather_nd picks one mask per detection.
                    indices = tf.stack([
                        tf.range(tf.size(final_labels)),
                        tf.to_int32(final_labels) - 1
                    ],
                                       axis=1)
                    final_mask_logits = tf.gather_nd(mask_logits,
                                                     indices)  # #resultx14x14
                    return tf.sigmoid(final_mask_logits)

                # With no detections an empty 0x14x14 tensor is produced
                # instead of running the mask branch.
                final_masks = tf.cond(
                    tf.size(final_probs) > 0, f1,
                    lambda: tf.zeros([0, 14, 14]))
                # Result discarded: the op exists only for its name.
                tf.identity(final_masks, name='final_masks')
Exemple #4
0
    def _build_graph(self, inputs):
        """
        Build the detection graph for the current tower.

        Training mode sums the RPN, Fast R-CNN, optional mask and
        weight-decay costs into ``self.cost``. Inference mode creates the
        named tensors ``fastrcnn_all_probs``, ``fastrcnn_all_boxes``,
        ``final_probs``, ``final_boxes``, ``final_labels`` and, when
        ``config.MODE_MASK`` is set, ``final_masks``. Nothing is returned.

        Args:
            inputs: image, anchor_labels, anchor_boxes, gt_boxes, gt_labels,
                followed by gt_masks when ``config.MODE_MASK`` is set.
        """
        is_training = get_current_tower_context().is_training
        if config.MODE_MASK:
            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
        else:
            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
        # Anchors are computed from the image before it is preprocessed.
        fm_anchors = self._get_anchors(image)
        image = self._preprocess(image)     # 1CHW
        image_shape2d = tf.shape(image)[2:]

        # ---- backbone + RPN ----
        anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
        featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3])
        rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)

        decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors)  # fHxfWxNAx4, floatbox
        # proposal_scores is not used further in this function.
        proposal_boxes, proposal_scores = generate_rpn_proposals(
            tf.reshape(decoded_boxes, [-1, 4]),
            tf.reshape(rpn_label_logits, [-1]),
            image_shape2d)

        # Boxes are multiplied by 1/ANCHOR_STRIDE before RoI pooling on the
        # conv4 feature map.
        if is_training:
            # sample proposal boxes in training
            rcnn_sampled_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
                proposal_boxes, gt_boxes, gt_labels)
            boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
        else:
            # use all proposal boxes in inference
            boxes_on_featuremap = proposal_boxes * (1.0 / config.ANCHOR_STRIDE)

        # roi_align is always executed; only conv5 + the head sit behind the
        # tf.cond guard below.
        roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)

        # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
        def ff_true():
            # Regular path: conv5 on the pooled RoIs, then the Fast R-CNN head.
            feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])    # nxcx7x7
            fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head('fastrcnn', feature_fastrcnn, config.NUM_CLASS)
            return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits

        def ff_false():
            # Zero-row stand-ins for the no-box case; shapes presumably
            # mirror ff_true's outputs (TODO confirm 2048 channels).
            ncls = config.NUM_CLASS
            return tf.zeros([0, 2048, 7, 7]), tf.zeros([0, ncls]), tf.zeros([0, ncls - 1, 4])

        feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
            tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)

        if is_training:
            # rpn loss
            rpn_label_loss, rpn_box_loss = rpn_losses(
                anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)

            # fastrcnn loss
            fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])   # fg inds w.r.t all samples
            fg_sampled_boxes = tf.gather(rcnn_sampled_boxes, fg_inds_wrt_sample)

            # Image summary of the foreground sample crops; it does not feed
            # any loss.
            with tf.name_scope('fg_sample_patch_viz'):
                # All-zero indices presumably assign every box to image 0 of
                # the batch -- confirm against crop_and_resize.
                fg_sampled_patches = crop_and_resize(
                    image, fg_sampled_boxes,
                    tf.zeros_like(fg_inds_wrt_sample, dtype=tf.int32), 300)
                fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1])
                fg_sampled_patches = tf.reverse(fg_sampled_patches, axis=[-1])  # BGR->RGB
                tf.summary.image('viz', fg_sampled_patches, max_outputs=30)

            # Regression targets are multiplied by FASTRCNN_BBOX_REG_WEIGHTS;
            # the inference branch divides the logits by the same constant.
            matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
            encoded_boxes = encode_bbox_target(
                matched_gt_boxes,
                fg_sampled_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
            fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
                rcnn_labels, fastrcnn_label_logits,
                encoded_boxes,
                tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample))

            if config.MODE_MASK:
                # maskrcnn loss
                # The mask head consumes the conv5 features of the
                # foreground samples only.
                fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
                fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
                mask_logits = maskrcnn_head('maskrcnn', fg_feature, config.NUM_CLASS)   # #fg x #cat x 14x14

                gt_masks_for_fg = tf.gather(gt_masks, fg_inds_wrt_gt)  # nfg x H x W
                # One gathered GT mask per foreground box, hence tf.range as
                # the index tensor.
                target_masks_for_fg = crop_and_resize(
                    tf.expand_dims(gt_masks_for_fg, 1),
                    fg_sampled_boxes,
                    tf.range(tf.size(fg_inds_wrt_gt)), 14)  # nfg x 1x14x14
                target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
                mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
            else:
                # Placeholder so the tf.add_n list below keeps one form.
                mrcnn_loss = 0.0

            # Weight decay over variables whose names match the regex.
            wd_cost = regularize_cost(
                '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
                l2_regularizer(1e-4), name='wd_cost')

            self.cost = tf.add_n([
                rpn_label_loss, rpn_box_loss,
                fastrcnn_label_loss, fastrcnn_box_loss,
                mrcnn_loss,
                wd_cost], 'total_cost')

            add_moving_summary(self.cost, wd_cost)
        else:
            # ---- inference: decode, clip and select final detections ----
            label_probs = tf.nn.softmax(fastrcnn_label_logits, name='fastrcnn_all_probs')  # #proposal x #Class
            anchors = tf.tile(tf.expand_dims(proposal_boxes, 1), [1, config.NUM_CLASS - 1, 1])   # #proposal x #Cat x 4
            decoded_boxes = decode_bbox_target(
                fastrcnn_box_logits /
                tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS), anchors)
            decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')

            # indices: Nx2. Each index into (#proposal, #category)
            pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
            # tf.identity is used here to give the tensor a fixed name.
            final_probs = tf.identity(final_probs, 'final_probs')
            final_boxes = tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
            # Category indices are 0-based; labels are shifted up by one.
            final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')

            if config.MODE_MASK:
                # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
                def f1():
                    # NOTE: roi_resized is local here and shadows the outer
                    # variable; masks are pooled from the final boxes.
                    roi_resized = roi_align(featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
                    feature_maskrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])
                    mask_logits = maskrcnn_head(
                        'maskrcnn', feature_maskrcnn, config.NUM_CLASS)   # #result x #cat x 14x14
                    # (row, label - 1) pairs select one mask per detection.
                    indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
                    final_mask_logits = tf.gather_nd(mask_logits, indices)   # #resultx14x14
                    return tf.sigmoid(final_mask_logits)

                # With no detections an empty 0x14x14 tensor is produced
                # instead of running the mask branch.
                final_masks = tf.cond(tf.size(final_probs) > 0, f1, lambda: tf.zeros([0, 14, 14]))
                # Result discarded: the op exists only for its name.
                tf.identity(final_masks, name='final_masks')
Exemple #5
0
    def _build_graph(self, inputs):
        """
        Build the Faster R-CNN graph for the current tower.

        In training mode the RPN, Fast R-CNN and weight-decay costs are
        summed into ``self.cost``. In inference mode the tensors
        ``fastrcnn_all_probs``, ``fastrcnn_all_boxes``, ``final_probs``,
        ``final_boxes`` and ``final_labels`` are created by name instead.
        Nothing is returned.

        Args:
            inputs: image, anchor_labels, anchor_boxes, gt_boxes, gt_labels
        """
        training = get_current_tower_context().is_training
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs

        # Anchors come from the raw image; preprocessing happens afterwards.
        fm_anchors = self._get_anchors(image)
        image = self._preprocess(image)  # 1CHW
        hw = tf.shape(image)[2:]
        inv_stride = 1.0 / config.ANCHOR_STRIDE

        # ---- backbone + RPN ----
        anchor_targets = encode_bbox_target(anchor_boxes, fm_anchors)
        c4 = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3])
        rpn_cls, rpn_reg = rpn_head('rpn', c4, 1024, config.NUM_ANCHOR)

        # fHxfWxNAx4, floatbox
        rpn_boxes = decode_bbox_target(rpn_reg, fm_anchors)
        flat_boxes = tf.reshape(rpn_boxes, [-1, 4])
        flat_scores = tf.reshape(rpn_cls, [-1])
        proposals, _ = generate_rpn_proposals(flat_boxes, flat_scores, hw)

        # ---- RoI features + Fast R-CNN head ----
        if training:
            # sample proposal boxes in training
            sampled_boxes, sampled_labels, fg_gt_inds = sample_fast_rcnn_targets(
                proposals, gt_boxes, gt_labels)
            roi_source = sampled_boxes
        else:
            # use all proposal boxes in inference
            roi_source = proposals
        rois_on_featuremap = roi_source * inv_stride

        pooled = roi_align(c4, rois_on_featuremap, 14)
        # nxcx7x7
        roi_feature = resnet_conv5(pooled, config.RESNET_NUM_BLOCK[-1])
        cls_logits, box_logits = fastrcnn_head(
            'fastrcnn', roi_feature, config.NUM_CLASS)

        if training:
            # rpn loss
            rpn_cls_loss, rpn_reg_loss = rpn_losses(
                anchor_labels, anchor_targets, rpn_cls, rpn_reg)

            # fastrcnn loss
            # fg inds w.r.t all samples
            fg_sample_inds = tf.reshape(tf.where(sampled_labels > 0), [-1])
            fg_boxes = tf.gather(sampled_boxes, fg_sample_inds)

            # Image summary of the foreground sample crops.
            with tf.name_scope('fg_sample_patch_viz'):
                crop_inds = tf.zeros_like(fg_sample_inds, dtype=tf.int32)
                patches = crop_and_resize(image, fg_boxes, crop_inds, 300)
                patches = tf.transpose(patches, [0, 2, 3, 1])
                tf.summary.image('viz', patches, max_outputs=30)

            # Targets are multiplied by the regression weights; inference
            # divides the logits by the same constant.
            gt_for_fg = tf.gather(gt_boxes, fg_gt_inds)
            box_targets = encode_bbox_target(gt_for_fg, fg_boxes)
            box_targets = box_targets * tf.constant(
                config.FASTRCNN_BBOX_REG_WEIGHTS)
            fg_box_logits = tf.gather(box_logits, fg_sample_inds)
            frcnn_cls_loss, frcnn_reg_loss = fastrcnn_losses(
                sampled_labels, cls_logits, box_targets, fg_box_logits)

            wd_cost = regularize_cost(
                '(?:group1|group2|group3|rpn|fastrcnn)/.*W',
                l2_regularizer(1e-4), name='wd_cost')

            all_costs = [rpn_cls_loss, rpn_reg_loss,
                         frcnn_cls_loss, frcnn_reg_loss,
                         wd_cost]
            self.cost = tf.add_n(all_costs, 'total_cost')

            add_moving_summary(self.cost, wd_cost)
        else:
            # #proposal x #Class
            class_probs = tf.nn.softmax(cls_logits, name='fastrcnn_all_probs')
            # #proposal x #Cat x 4
            tiled_proposals = tf.tile(tf.expand_dims(proposals, 1),
                                      [1, config.NUM_CLASS - 1, 1])
            scaled_logits = box_logits / tf.constant(
                config.FASTRCNN_BBOX_REG_WEIGHTS)
            all_boxes = decode_bbox_target(scaled_logits, tiled_proposals)
            all_boxes = clip_boxes(all_boxes, hw, name='fastrcnn_all_boxes')

            # picked: Nx2. Each index into (#proposal, #category)
            picked, picked_probs = fastrcnn_predictions(all_boxes, class_probs)
            # The three ops below are created only for their tensor names.
            tf.identity(picked_probs, 'final_probs')
            tf.gather_nd(all_boxes, picked, name='final_boxes')
            tf.add(picked[:, 1], 1, name='final_labels')
Exemple #6
0
    def _build_graph(self, inputs):
        """
        Build the Faster R-CNN graph for the current tower.

        The RPN losses are constructed in both modes. In training mode they
        are summed with the Fast R-CNN and weight-decay costs into
        ``self.cost``. In inference mode the named tensors
        ``fastrcnn_all_probs``, ``fastrcnn_fg_probs`` and
        ``fastrcnn_fg_boxes`` are created instead. Nothing is returned.

        Args:
            inputs: image, anchor_labels, anchor_boxes, gt_boxes, gt_labels
        """
        training = get_current_tower_context().is_training
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
        # Add a leading batch axis to the single input image.
        image = tf.expand_dims(image, 0)
        stride = config.ANCHOR_STRIDE

        # FSxFSxNAx4 (FS=MAX_SIZE//ANCHOR_STRIDE)
        with tf.name_scope('anchors'):
            all_anchors = tf.constant(
                get_all_anchors(), name='all_anchors', dtype=tf.float32)
            # Keep only the part of the anchor grid covering this image.
            fm_height = tf.shape(image)[1] // stride
            fm_width = tf.shape(image)[2] // stride
            slice_size = tf.stack([fm_height, fm_width, -1, -1])
            fm_anchors = tf.slice(
                all_anchors, [0, 0, 0, 0], slice_size, name='fm_anchors')
            anchor_targets = encode_bbox_target(anchor_boxes, fm_anchors)

        image = image_preprocess(image, bgr=True)
        image = tf.transpose(image, [0, 3, 1, 2])

        # resnet50
        c4 = pretrained_resnet_conv4(image, [3, 4, 6])
        rpn_cls, rpn_reg = rpn_head(c4)
        rpn_cls_loss, rpn_reg_loss = rpn_losses(
            anchor_labels, anchor_targets, rpn_cls, rpn_reg)

        # (fHxfWxNA)x4, floatbox
        rpn_boxes = decode_bbox_target(rpn_reg, fm_anchors)
        flat_scores = tf.reshape(rpn_cls, [-1])
        image_hw = tf.shape(image)[2:]
        proposals, _ = generate_rpn_proposals(rpn_boxes, flat_scores, image_hw)

        # ---- RoI features + Fast R-CNN head (shared by both modes) ----
        inv_stride = 1.0 / stride
        if training:
            sampled_boxes, box_targets, sampled_labels = sample_fast_rcnn_targets(
                proposals, gt_boxes, gt_labels)
            roi_source = sampled_boxes
        else:
            roi_source = proposals
        pooled = roi_align(c4, roi_source * inv_stride, 14)
        roi_feature = resnet_conv5(pooled)  # nxc
        cls_logits, box_logits = fastrcnn_head(roi_feature, config.NUM_CLASS)

        if training:
            frcnn_cls_loss, frcnn_reg_loss = fastrcnn_losses(
                sampled_labels, box_targets, cls_logits, box_logits)

            wd_cost = regularize_cost(
                '(?:group1|group2|group3|rpn|fastrcnn)/.*W',
                l2_regularizer(1e-4), name='wd_cost')

            all_costs = [rpn_cls_loss, rpn_reg_loss,
                         frcnn_cls_loss, frcnn_reg_loss,
                         wd_cost]
            self.cost = tf.add_n(all_costs, 'total_cost')

            add_moving_summary(self.cost)
            add_moving_summary(wd_cost)
        else:
            # NP,
            class_probs = tf.nn.softmax(cls_logits, name='fastrcnn_all_probs')
            predicted = tf.argmax(cls_logits, axis=1)
            fg_inds, fg_logits = fastrcnn_predict_boxes(predicted, box_logits)
            # Created for its tensor name; the result is not used below.
            tf.gather(class_probs, fg_inds, name='fastrcnn_fg_probs')
            fg_proposals = tf.gather(proposals, fg_inds)

            fg_logits = fg_logits / tf.constant(
                config.FASTRCNN_BBOX_REG_WEIGHTS)
            # Nfx4, floatbox
            fg_decoded = decode_bbox_target(fg_logits, fg_proposals)
            tf.identity(fg_decoded, name='fastrcnn_fg_boxes')
Exemple #7
0
    def build_graph(self, *inputs):
        """
        Build the FPN-based detection graph for the current tower.

        Args:
            *inputs: positional tensors laid out as
                ``image``,
                then ``2 * len(config.ANCHOR_STRIDES_FPN)`` anchor tensors
                interleaved per level as (anchor_labels, anchor_boxes),
                then ``gt_boxes``, ``gt_labels``,
                and ``gt_masks`` as the last element when
                ``config.MODE_MASK`` is set.

        Returns:
            The total cost tensor (``total_cost``) in training mode. In
            inference mode nothing is returned; the outputs are the tensors
            created by ``self.fastrcnn_inference`` and, when
            ``config.MODE_MASK`` is set, ``final_masks``.
        """
        num_fpn_level = len(config.ANCHOR_STRIDES_FPN)
        assert len(config.ANCHOR_SIZES) == num_fpn_level
        is_training = get_current_tower_context().is_training
        image = inputs[0]
        # The ground-truth tensors follow the per-level anchor inputs. Their
        # position is derived from the number of levels rather than
        # hard-coded (11/12 is only right for exactly five levels).
        gt_start = 1 + 2 * num_fpn_level
        input_anchors = inputs[1:gt_start]
        multilevel_anchor_labels = input_anchors[0::2]
        multilevel_anchor_boxes = input_anchors[1::2]
        gt_boxes, gt_labels = inputs[gt_start], inputs[gt_start + 1]
        if config.MODE_MASK:
            gt_masks = inputs[-1]

        image = self.preprocess(image)  # 1CHW
        image_shape2d = tf.shape(image)[2:]  # h,w

        c2345 = resnet_fpn_backbone(image, config.RESNET_NUM_BLOCK)
        p23456 = fpn_model('fpn', c2345)

        # Multi-Level RPN Proposals
        multilevel_proposals = []
        # Flat list [label_loss, box_loss, label_loss, box_loss, ...], one
        # pair per level; only filled in training mode.
        rpn_loss_collection = []
        for lvl in range(num_fpn_level):
            rpn_label_logits, rpn_box_logits = rpn_head(
                'rpn', p23456[lvl], config.FPN_NUM_CHANNEL,
                len(config.ANCHOR_RATIOS))
            with tf.name_scope('FPN_lvl{}'.format(lvl + 2)):
                anchors = tf.constant(get_all_anchors_fpn()[lvl],
                                      name='rpn_anchor_lvl{}'.format(lvl + 2))
                # Restrict anchors and their targets to this level's
                # feature map.
                anchors, anchor_labels, anchor_boxes = \
                    self.narrow_to_featuremap(p23456[lvl], anchors,
                                              multilevel_anchor_labels[lvl],
                                              multilevel_anchor_boxes[lvl])
                anchor_boxes_encoded = encode_bbox_target(
                    anchor_boxes, anchors)
                pred_boxes_decoded = decode_bbox_target(
                    rpn_box_logits, anchors)
                proposal_boxes, proposal_scores = generate_rpn_proposals(
                    tf.reshape(pred_boxes_decoded, [-1, 4]),
                    tf.reshape(rpn_label_logits, [-1]), image_shape2d,
                    config.TRAIN_FPN_NMS_TOPK
                    if is_training else config.TEST_FPN_NMS_TOPK)
                multilevel_proposals.append((proposal_boxes, proposal_scores))
                if is_training:
                    label_loss, box_loss = rpn_losses(anchor_labels,
                                                      anchor_boxes_encoded,
                                                      rpn_label_logits,
                                                      rpn_box_logits)
                    rpn_loss_collection.extend([label_loss, box_loss])

        # Merge proposals from multi levels, pick top K
        proposal_boxes = tf.concat([x[0] for x in multilevel_proposals],
                                   axis=0)  # nx4
        proposal_scores = tf.concat([x[1] for x in multilevel_proposals],
                                    axis=0)  # n
        # K is capped by the number of available proposals.
        proposal_topk = tf.minimum(
            tf.size(proposal_scores), config.TRAIN_FPN_NMS_TOPK
            if is_training else config.TEST_FPN_NMS_TOPK)
        proposal_scores, topk_indices = tf.nn.top_k(proposal_scores,
                                                    k=proposal_topk,
                                                    sorted=False)
        proposal_boxes = tf.gather(proposal_boxes, topk_indices)

        if is_training:
            rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
                proposal_boxes, gt_boxes, gt_labels)
        else:
            # The boxes to be used to crop RoIs.
            rcnn_boxes = proposal_boxes

        # Only the first four pyramid levels are passed to RoI align.
        roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)

        fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_2fc_head(
            'fastrcnn', roi_feature_fastrcnn, config.NUM_CLASS)

        if is_training:
            # rpn loss is already defined above
            # These per-kind totals are built for the moving summaries only;
            # the total cost below sums the individual per-level losses.
            with tf.name_scope('rpn_losses'):
                rpn_total_label_loss = tf.add_n(rpn_loss_collection[::2],
                                                name='label_loss')
                rpn_total_box_loss = tf.add_n(rpn_loss_collection[1::2],
                                              name='box_loss')
                add_moving_summary(rpn_total_box_loss, rpn_total_label_loss)

            # fastrcnn loss:
            matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)

            fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0),
                                            [-1])  # fg inds w.r.t all samples
            fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
            fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits,
                                               fg_inds_wrt_sample)

            fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
                image, rcnn_labels, fg_sampled_boxes, matched_gt_boxes,
                fastrcnn_label_logits, fg_fastrcnn_box_logits)

            if config.MODE_MASK:
                # maskrcnn loss
                fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
                roi_feature_maskrcnn = multilevel_roi_align(
                    p23456[:4], fg_sampled_boxes, 14)
                mask_logits = maskrcnn_upXconv_head('maskrcnn',
                                                    roi_feature_maskrcnn,
                                                    config.NUM_CLASS,
                                                    4)  # #fg x #cat x 28 x 28

                matched_gt_masks = tf.gather(gt_masks,
                                             fg_inds_wrt_gt)  # fg x H x W
                # One gathered GT mask per foreground box, hence tf.range
                # as the index tensor.
                target_masks_for_fg = crop_and_resize(
                    tf.expand_dims(matched_gt_masks, 1),
                    fg_sampled_boxes,
                    tf.range(tf.size(fg_inds_wrt_gt)),
                    28,
                    pad_border=False)  # fg x 1x28x28
                target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1,
                                                 'sampled_fg_mask_targets')
                mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels,
                                           target_masks_for_fg)
            else:
                # Placeholder so the tf.add_n list below keeps one form.
                mrcnn_loss = 0.0

            # Weight decay over variables whose names match the regex.
            wd_cost = regularize_cost(
                '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
                l2_regularizer(1e-4),
                name='wd_cost')

            total_cost = tf.add_n(
                rpn_loss_collection +
                [fastrcnn_label_loss, fastrcnn_box_loss, mrcnn_loss, wd_cost],
                'total_cost')

            add_moving_summary(total_cost, wd_cost)
            return total_cost
        else:
            final_boxes, final_labels = self.fastrcnn_inference(
                image_shape2d, rcnn_boxes, fastrcnn_label_logits,
                fastrcnn_box_logits)
            if config.MODE_MASK:
                # Cascade inference needs roi transform with refined boxes.
                roi_feature_maskrcnn = multilevel_roi_align(
                    p23456[:4], final_boxes, 14)
                mask_logits = maskrcnn_upXconv_head('maskrcnn',
                                                    roi_feature_maskrcnn,
                                                    config.NUM_CLASS,
                                                    4)  # #fg x #cat x 28 x 28
                # (row, label - 1) pairs select one mask per detection.
                indices = tf.stack([
                    tf.range(tf.size(final_labels)),
                    tf.to_int32(final_labels) - 1
                ],
                                   axis=1)
                final_mask_logits = tf.gather_nd(mask_logits,
                                                 indices)  # #resultx28x28
                # Result discarded: the op exists only for its name.
                tf.sigmoid(final_mask_logits, name='final_masks')
Exemple #8
0
    def build_graph(self, *inputs):
        """
        Build the single-feature-map (ResNet C4) detection graph for the
        current tower: backbone -> RPN -> proposals -> RoI head, followed by
        either the training losses or the inference outputs.

        Args:
            inputs: image, anchor_labels, anchor_boxes, gt_boxes, gt_labels,
                plus a trailing gt_masks when ``config.MODE_MASK`` is set.

        Returns:
            The 'total_cost' tensor when the tower context is in training.
            In inference nothing is returned; the outputs are produced by
            ``self.fastrcnn_inference`` and, in mask mode, a tensor named
            'final_masks' is added to the graph.
        """
        is_training = get_current_tower_context().is_training
        if config.MODE_MASK:
            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
        else:
            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
        image = self.preprocess(image)  # 1CHW

        # Backbone feature map shared by the RPN and the RoI head.
        featuremap = resnet_c4_backbone(image, config.RESNET_NUM_BLOCK[:3])
        rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024,
                                                    config.NUM_ANCHOR)

        # Restrict the anchors (and their labels/boxes) to the feature map and
        # encode the anchor box targets relative to those anchors.
        fm_anchors, anchor_labels, anchor_boxes = self.narrow_to_featuremap(
            featuremap, get_all_anchors(), anchor_labels, anchor_boxes)
        anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)

        image_shape2d = tf.shape(image)[2:]  # h,w
        pred_boxes_decoded = decode_bbox_target(
            rpn_box_logits, fm_anchors)  # fHxfWxNAx4, floatbox
        # The pre-/post-NMS top-k settings differ between training and testing.
        proposal_boxes, proposal_scores = generate_rpn_proposals(
            tf.reshape(pred_boxes_decoded, [-1, 4]),
            tf.reshape(rpn_label_logits,
                       [-1]), image_shape2d, config.TRAIN_PRE_NMS_TOPK
            if is_training else config.TEST_PRE_NMS_TOPK,
            config.TRAIN_POST_NMS_TOPK
            if is_training else config.TEST_POST_NMS_TOPK)

        if is_training:
            # sample proposal boxes in training
            rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
                proposal_boxes, gt_boxes, gt_labels)
        else:
            # The boxes to be used to crop RoIs.
            # Use all proposal boxes in inference
            rcnn_boxes = proposal_boxes

        # Rescale image-space boxes by 1/ANCHOR_STRIDE before RoI cropping.
        boxes_on_featuremap = rcnn_boxes * (1.0 / config.ANCHOR_STRIDE)
        roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)

        # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
        # which was fixed in TF 1.6
        def ff_true():
            # RoI head: conv5 on the cropped RoIs, global average pooling,
            # then the classification / box-regression outputs.
            feature_fastrcnn = resnet_conv5(
                roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxcx7x7
            feature_gap = GlobalAvgPooling('gap',
                                           feature_fastrcnn,
                                           data_format='channels_first')
            fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs(
                'fastrcnn', feature_gap, config.NUM_CLASS)
            # Return C5 feature to be shared with mask branch
            return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits

        def ff_false():
            # Zero-row stand-ins used when there are no boxes; the 2048x7x7
            # feature shape is hard-coded here.
            ncls = config.NUM_CLASS
            return tf.zeros([0, 2048, 7,
                             7]), tf.zeros([0,
                                            ncls]), tf.zeros([0, ncls - 1, 4])

        # NOTE(review): this compares the version as a number; if the helper
        # returns a float, a release such as "1.10" may compare as 1.1 and fall
        # into the legacy branch -- confirm against get_tf_version_number.
        if get_tf_version_number() >= 1.6:
            feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = ff_true(
            )
        else:
            logger.warn("This example may drop support for TF < 1.6 soon.")
            # Only build the head on a non-empty set of boxes.
            feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
                tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)

        if is_training:
            # rpn loss
            rpn_label_loss, rpn_box_loss = rpn_losses(anchor_labels,
                                                      anchor_boxes_encoded,
                                                      rpn_label_logits,
                                                      rpn_box_logits)

            # fastrcnn loss
            matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)

            # Samples with label > 0 are treated as foreground; only those
            # contribute boxes and box logits to the regression loss.
            fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0),
                                            [-1])  # fg inds w.r.t all samples
            fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
            fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits,
                                               fg_inds_wrt_sample)

            fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
                image, rcnn_labels, fg_sampled_boxes, matched_gt_boxes,
                fastrcnn_label_logits, fg_fastrcnn_box_logits)

            if config.MODE_MASK:
                # maskrcnn loss
                fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
                # In training, mask branch shares the same C5 feature.
                fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
                mask_logits = maskrcnn_upXconv_head(
                    'maskrcnn', fg_feature, config.NUM_CLASS,
                    num_convs=0)  # #fg x #cat x 14x14

                # Gather one gt mask per foreground sample, then crop each
                # mask with its own sampled box (index i -> mask i).
                matched_gt_masks = tf.gather(gt_masks,
                                             fg_inds_wrt_gt)  # nfg x H x W
                target_masks_for_fg = crop_and_resize(
                    tf.expand_dims(matched_gt_masks, 1),
                    fg_sampled_boxes,
                    tf.range(tf.size(fg_inds_wrt_gt)),
                    14,
                    pad_border=False)  # nfg x 1x14x14
                target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1,
                                                 'sampled_fg_mask_targets')
                mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels,
                                           target_masks_for_fg)
            else:
                # Constant placeholder so the sum below has a fixed arity.
                mrcnn_loss = 0.0

            # L2 regularization (1e-4) over the parameters selected by the
            # name pattern below.
            wd_cost = regularize_cost(
                '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
                l2_regularizer(1e-4),
                name='wd_cost')

            total_cost = tf.add_n([
                rpn_label_loss, rpn_box_loss, fastrcnn_label_loss,
                fastrcnn_box_loss, mrcnn_loss, wd_cost
            ], 'total_cost')

            add_moving_summary(total_cost, wd_cost)
            return total_cost
        else:
            final_boxes, final_labels = self.fastrcnn_inference(
                image_shape2d, rcnn_boxes, fastrcnn_label_logits,
                fastrcnn_box_logits)

            if config.MODE_MASK:
                # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
                def f1():
                    # In inference the mask head runs on features re-cropped
                    # with the final detection boxes, not the proposals.
                    roi_resized = roi_align(
                        featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE),
                        14)
                    feature_maskrcnn = resnet_conv5(
                        roi_resized, config.RESNET_NUM_BLOCK[-1])
                    mask_logits = maskrcnn_upXconv_head(
                        'maskrcnn', feature_maskrcnn, config.NUM_CLASS,
                        0)  # #result x #cat x 14x14
                    # Pick, for each detection, the mask of its predicted
                    # category (labels are 1-based, mask channels 0-based).
                    indices = tf.stack([
                        tf.range(tf.size(final_labels)),
                        tf.to_int32(final_labels) - 1
                    ],
                                       axis=1)
                    final_mask_logits = tf.gather_nd(mask_logits,
                                                     indices)  # #resultx14x14
                    return tf.sigmoid(final_mask_logits)

                # With no detections, emit an empty 0x14x14 mask tensor.
                final_masks = tf.cond(
                    tf.size(final_labels) > 0, f1,
                    lambda: tf.zeros([0, 14, 14]))
                tf.identity(final_masks, name='final_masks')
Exemple #9
0
    def _build_graph(self, inputs):
        """
        Build the C4 detection graph with an optional second classification
        head and optional externally provided boxes.

        The layout of ``inputs`` depends on ``config.USE_SECOND_HEAD``,
        ``config.MODE_MASK`` and ``config.PROVIDE_BOXES_AS_INPUT`` (see the
        unpacking below).

        In training the summed loss is stored in ``self.cost``; in inference
        the named output tensors ('final_probs', 'final_boxes', 'final_labels',
        'final_posterior', 'feature_fastrcnn_pooled', and optionally the
        second-head and 'final_masks' tensors) are added to the graph.
        Nothing is returned in either mode.
        """
        is_training = get_current_tower_context().is_training

        # NOTE(review): `input_boxes` is only unpacked in the MODE_MASK
        # branches, yet it is read below whenever PROVIDE_BOXES_AS_INPUT is
        # set -- confirm that flag is never enabled without MODE_MASK.
        if config.USE_SECOND_HEAD:
            # USE SECOND HEAD
            if config.MODE_MASK:
                if config.PROVIDE_BOXES_AS_INPUT:
                    image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, second_gt_labels, gt_masks, input_boxes = inputs
                else:
                    image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, second_gt_labels, gt_masks = inputs
            else:
                image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, second_gt_labels = inputs
        else:
            # NO SECOND HEAD
            if config.MODE_MASK:
                if config.PROVIDE_BOXES_AS_INPUT:
                    image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks, input_boxes = inputs
                else:
                    image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
            else:
                image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
            second_gt_labels = None

        fm_anchors = self._get_anchors(image)
        # Shape is recorded before preprocessing; presumably the raw image is
        # HWC at this point, so [:2] is (h, w) -- TODO confirm.
        image_shape2d_before_resize = tf.shape(image)[:2]
        image = self._preprocess(image)  # 1CHW
        image_shape2d = tf.shape(image)[2:]

        anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
        featuremap = pretrained_resnet_conv4(image,
                                             config.RESNET_NUM_BLOCK[:3])
        rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024,
                                                    config.NUM_ANCHOR)

        decoded_boxes = decode_bbox_target(rpn_box_logits,
                                           fm_anchors)  # fHxfWxNAx4, floatbox
        proposal_boxes, proposal_scores = generate_rpn_proposals(
            tf.reshape(decoded_boxes, [-1, 4]),
            tf.reshape(rpn_label_logits, [-1]), image_shape2d)

        if config.PROVIDE_BOXES_AS_INPUT:
            # Replace the RPN proposals with the caller-supplied boxes,
            # rescaled from the pre-preprocessing image size to the
            # preprocessed size.
            old_height, old_width = image_shape2d_before_resize[
                0], image_shape2d_before_resize[1]
            new_height, new_width = image_shape2d[0], image_shape2d[1]
            height_scale = new_height / old_height
            width_scale = new_width / old_width
            # TODO: check the order of dimensions!
            # (the scale vector below assumes boxes are laid out x, y, x, y)
            scale = tf.stack(
                [width_scale, height_scale, width_scale, height_scale], axis=0)
            proposal_boxes = input_boxes * tf.cast(scale, tf.float32)

        secondclassification_labels = None
        if is_training:
            # sample proposal boxes in training
            rcnn_sampled_boxes, rcnn_labels, fg_inds_wrt_gt, bg_inds = sample_fast_rcnn_targets(
                proposal_boxes, gt_boxes, gt_labels)
            if config.USE_SECOND_HEAD:
                # Second-head targets: the matched gt's second label for each
                # foreground sample, zero for each background sample.
                # NOTE(review): this concat assumes the sampled boxes are
                # ordered foreground first, then background -- confirm against
                # sample_fast_rcnn_targets.
                secondclassification_labels = tf.stop_gradient(
                    tf.concat([
                        tf.gather(second_gt_labels, fg_inds_wrt_gt),
                        tf.zeros_like(bg_inds, dtype=tf.int64)
                    ],
                              axis=0,
                              name='second_sampled_labels'))
            boxes_on_featuremap = rcnn_sampled_boxes * (1.0 /
                                                        config.ANCHOR_STRIDE)
        else:
            # use all proposal boxes in inference
            boxes_on_featuremap = proposal_boxes * (1.0 / config.ANCHOR_STRIDE)
            rcnn_labels = rcnn_sampled_boxes = fg_inds_wrt_gt = None

        roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)

        # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
        def ff_true():
            feature_fastrcnn = resnet_conv5(
                roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxcx7x7
            if config.TRAIN_HEADS_ONLY:
                # Cut the gradient here so only the heads receive updates
                # from the head losses.
                feature_fastrcnn = tf.stop_gradient(feature_fastrcnn)
            fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(
                'fastrcnn', feature_fastrcnn, config.NUM_CLASS)
            return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits

        def ff_false():
            # Zero-row stand-ins used when there are no boxes; the 2048x7x7
            # feature shape is hard-coded here.
            ncls = config.NUM_CLASS
            return tf.zeros([0, 2048, 7,
                             7]), tf.zeros([0,
                                            ncls]), tf.zeros([0, ncls - 1, 4])

        feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
            tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)

        # Average over the two trailing (spatial) axes of the nxcx7x7 feature.
        feature_fastrcnn_pooled = tf.reduce_mean(feature_fastrcnn, axis=[2, 3])

        if config.USE_SECOND_HEAD:

            def if_true():
                secondclassification_label_logits = secondclassification_head(
                    'secondclassification', feature_fastrcnn,
                    config.SECOND_NUM_CLASS)
                return secondclassification_label_logits

            def if_false():
                ncls = config.SECOND_NUM_CLASS
                return tf.zeros([0, ncls])

            # Same empty-input guard as for the main head above.
            secondclassification_label_logits = tf.cond(
                tf.size(boxes_on_featuremap) > 0, if_true, if_false)
        else:
            secondclassification_label_logits = None

        if is_training:
            # rpn loss
            rpn_label_loss, rpn_box_loss = rpn_losses(anchor_labels,
                                                      anchor_boxes_encoded,
                                                      rpn_label_logits,
                                                      rpn_box_logits)

            # fastrcnn loss
            fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0),
                                            [-1])  # fg inds w.r.t all samples
            fg_sampled_boxes = tf.gather(rcnn_sampled_boxes,
                                         fg_inds_wrt_sample)

            # Image summary of the foreground crops (debug visualization only;
            # it does not feed into any loss).
            with tf.name_scope('fg_sample_patch_viz'):
                fg_sampled_patches = crop_and_resize(
                    image, fg_sampled_boxes,
                    tf.zeros_like(fg_inds_wrt_sample, dtype=tf.int32), 300)
                fg_sampled_patches = tf.transpose(fg_sampled_patches,
                                                  [0, 2, 3, 1])
                fg_sampled_patches = tf.reverse(fg_sampled_patches,
                                                axis=[-1])  # BGR->RGB
                tf.summary.image('viz', fg_sampled_patches, max_outputs=30)

            # Regression targets: matched gt boxes encoded relative to the
            # sampled foreground boxes, scaled by the regression weights.
            matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
            encoded_boxes = encode_bbox_target(
                matched_gt_boxes, fg_sampled_boxes) * tf.constant(
                    config.FASTRCNN_BBOX_REG_WEIGHTS)
            fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
                rcnn_labels, fastrcnn_label_logits, encoded_boxes,
                tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample))

            if config.USE_SECOND_HEAD:
                secondclassification_label_loss = secondclassification_losses(
                    secondclassification_labels,
                    secondclassification_label_logits)
            else:
                secondclassification_label_loss = None

            if config.MODE_MASK:
                # maskrcnn loss
                fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
                fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
                mask_logits = maskrcnn_head(
                    'maskrcnn', fg_feature,
                    config.NUM_CLASS)  # #fg x #cat x 14x14

                # Gather one gt mask per foreground sample, then crop each
                # mask with its own sampled box (index i -> mask i).
                gt_masks_for_fg = tf.gather(gt_masks,
                                            fg_inds_wrt_gt)  # nfg x H x W
                target_masks_for_fg = crop_and_resize(
                    tf.expand_dims(gt_masks_for_fg, 1), fg_sampled_boxes,
                    tf.range(tf.size(fg_inds_wrt_gt)), 14)  # nfg x 1x14x14
                target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1,
                                                 'sampled_fg_mask_targets')
                mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels,
                                           target_masks_for_fg)
            else:
                # Constant placeholder so the sums below keep a fixed arity.
                mrcnn_loss = 0.0

            # L2 regularization (1e-4) over the parameters selected by the
            # name pattern below.
            # NOTE(review): the pattern has no 'secondclassification' entry --
            # confirm whether the second head is meant to be unregularized.
            wd_cost = regularize_cost(
                '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
                l2_regularizer(1e-4),
                name='wd_cost')

            # Assemble the total cost; the four variants differ only in
            # whether the RPN losses and the second-head loss are included.
            if config.TRAIN_HEADS_ONLY:
                # don't include the rpn loss
                if config.USE_SECOND_HEAD:
                    self.cost = tf.add_n([
                        fastrcnn_label_loss, fastrcnn_box_loss,
                        secondclassification_label_loss, mrcnn_loss, wd_cost
                    ], 'total_cost')
                else:
                    self.cost = tf.add_n([
                        fastrcnn_label_loss, fastrcnn_box_loss, mrcnn_loss,
                        wd_cost
                    ], 'total_cost')
            else:
                if config.USE_SECOND_HEAD:
                    self.cost = tf.add_n([
                        rpn_label_loss, rpn_box_loss, fastrcnn_label_loss,
                        fastrcnn_box_loss, secondclassification_label_loss,
                        mrcnn_loss, wd_cost
                    ], 'total_cost')
                else:
                    self.cost = tf.add_n([
                        rpn_label_loss, rpn_box_loss, fastrcnn_label_loss,
                        fastrcnn_box_loss, mrcnn_loss, wd_cost
                    ], 'total_cost')

            add_moving_summary(self.cost, wd_cost)
        else:
            label_probs = tf.nn.softmax(
                fastrcnn_label_logits,
                name='fastrcnn_all_probs')  # #proposal x #Class
            # Repeat every proposal once per category so each per-category
            # regression output is decoded against its own copy of the box.
            anchors = tf.tile(
                tf.expand_dims(proposal_boxes, 1),
                [1, config.NUM_CLASS - 1, 1])  # #proposal x #Cat x 4
            # Undo the regression-weight scaling applied to the training
            # targets before decoding.
            decoded_boxes = decode_bbox_target(
                fastrcnn_box_logits /
                tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS), anchors)
            decoded_boxes = clip_boxes(decoded_boxes,
                                       image_shape2d,
                                       name='fastrcnn_all_boxes')

            # indices: Nx2. Each index into (#proposal, #category)
            pred_indices, final_probs = fastrcnn_predictions(
                decoded_boxes, label_probs)
            final_probs = tf.identity(final_probs, 'final_probs')
            final_boxes = tf.gather_nd(decoded_boxes,
                                       pred_indices,
                                       name='final_boxes')

            pred_indices_only_boxes = pred_indices[:,
                                                   0]  # column 0 is the proposal index (column 1 is the category); the per-proposal gathers below need the proposal index
            final_labels = tf.add(pred_indices[:, 1], 1,
                                  name='final_labels')  # category index shifted by one
            # Full class distribution of the proposal behind each detection.
            final_posterior = tf.gather(label_probs,
                                        pred_indices_only_boxes,
                                        name='final_posterior')

            # Pooled feature of the proposal behind each detection.
            feature_fastrcnn_pooled = tf.gather(feature_fastrcnn_pooled,
                                                pred_indices_only_boxes,
                                                name="feature_fastrcnn_pooled")

            if config.USE_SECOND_HEAD:
                second_label_probs = tf.nn.softmax(
                    secondclassification_label_logits,
                    name='secondclassification_all_probs'
                )  # #proposal x #Class
                second_final_posterior = tf.gather(
                    second_label_probs,
                    pred_indices_only_boxes,
                    name='second_final_posterior')
                # Argmax over the second-head posterior, shifted by one.
                second_final_labels = tf.add(tf.argmax(second_final_posterior,
                                                       -1),
                                             1,
                                             name='second_final_labels')
            if config.MODE_MASK:
                # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
                def f1():
                    # In inference the mask head runs on features re-cropped
                    # with the final detection boxes, not the proposals.
                    roi_resized = roi_align(
                        featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE),
                        14)
                    feature_maskrcnn = resnet_conv5(
                        roi_resized, config.RESNET_NUM_BLOCK[-1])
                    mask_logits = maskrcnn_head(
                        'maskrcnn', feature_maskrcnn,
                        config.NUM_CLASS)  # #result x #cat x 14x14
                    # Pick, for each detection, the mask of its predicted
                    # category (labels are 1-based, mask channels 0-based).
                    indices = tf.stack([
                        tf.range(tf.size(final_labels)),
                        tf.to_int32(final_labels) - 1
                    ],
                                       axis=1)
                    final_mask_logits = tf.gather_nd(mask_logits,
                                                     indices)  # #resultx14x14
                    return tf.sigmoid(final_mask_logits)

                # With no detections, emit an empty 0x14x14 mask tensor.
                final_masks = tf.cond(
                    tf.size(final_probs) > 0, f1,
                    lambda: tf.zeros([0, 14, 14]))
                tf.identity(final_masks, name='final_masks')
Exemple #10
0
    def build_graph(self, *inputs):
        """
        Build the FPN Faster/Mask R-CNN graph for the current tower:
        backbone -> FPN -> per-level RPN -> merged proposals -> RoI heads,
        followed by either the training losses or the inference outputs.

        Args:
            inputs: positional tensors laid out as
                ``image``,
                then ``anchor_labels, anchor_boxes`` for every FPN level
                (interleaved, ``len(config.ANCHOR_STRIDES_FPN)`` pairs),
                then ``gt_boxes, gt_labels``,
                and ``gt_masks`` as the last element when
                ``config.MODE_MASK`` is set.

        Returns:
            The 'total_cost' tensor when the tower context is in training.
            In inference nothing is returned; the outputs are produced by
            ``self.fastrcnn_inference`` and, in mask mode, a tensor named
            'final_masks' is added to the graph.
        """
        num_fpn_level = len(config.ANCHOR_STRIDES_FPN)
        assert len(config.ANCHOR_SIZES) == num_fpn_level
        is_training = get_current_tower_context().is_training
        image = inputs[0]
        # Anchor inputs are interleaved per level: labels, boxes, labels, ...
        num_anchor_inputs = 2 * num_fpn_level
        input_anchors = inputs[1: 1 + num_anchor_inputs]
        multilevel_anchor_labels = input_anchors[0::2]
        multilevel_anchor_boxes = input_anchors[1::2]
        # The ground-truth inputs directly follow the anchor inputs. Derive
        # their position from the number of levels instead of hard-coding
        # 11/12, which is only right for exactly five FPN levels.
        gt_start = 1 + num_anchor_inputs
        gt_boxes, gt_labels = inputs[gt_start], inputs[gt_start + 1]
        if config.MODE_MASK:
            gt_masks = inputs[-1]

        image = self.preprocess(image)     # 1CHW
        image_shape2d = tf.shape(image)[2:]     # h,w

        c2345 = resnet_fpn_backbone(image, config.RESNET_NUM_BLOCK)
        p23456 = fpn_model('fpn', c2345)

        # Multi-Level RPN Proposals
        multilevel_proposals = []
        # Flat list of [label_loss, box_loss] pairs, one pair per level.
        rpn_loss_collection = []
        for lvl in range(num_fpn_level):
            # The same 'rpn' head name is used on every pyramid level.
            rpn_label_logits, rpn_box_logits = rpn_head(
                'rpn', p23456[lvl], config.FPN_NUM_CHANNEL, len(config.ANCHOR_RATIOS))
            with tf.name_scope('FPN_lvl{}'.format(lvl + 2)):
                anchors = tf.constant(get_all_anchors_fpn()[lvl], name='rpn_anchor_lvl{}'.format(lvl + 2))
                anchors, anchor_labels, anchor_boxes = \
                    self.narrow_to_featuremap(p23456[lvl], anchors,
                                              multilevel_anchor_labels[lvl],
                                              multilevel_anchor_boxes[lvl])
                anchor_boxes_encoded = encode_bbox_target(anchor_boxes, anchors)
                pred_boxes_decoded = decode_bbox_target(rpn_box_logits, anchors)
                proposal_boxes, proposal_scores = generate_rpn_proposals(
                    tf.reshape(pred_boxes_decoded, [-1, 4]),
                    tf.reshape(rpn_label_logits, [-1]),
                    image_shape2d,
                    config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK)
                multilevel_proposals.append((proposal_boxes, proposal_scores))
                if is_training:
                    label_loss, box_loss = rpn_losses(
                        anchor_labels, anchor_boxes_encoded,
                        rpn_label_logits, rpn_box_logits)
                    rpn_loss_collection.extend([label_loss, box_loss])

        # Merge proposals from multi levels, pick top K
        proposal_boxes = tf.concat([x[0] for x in multilevel_proposals], axis=0)  # nx4
        proposal_scores = tf.concat([x[1] for x in multilevel_proposals], axis=0)  # n
        # Never request more entries than there are proposals.
        proposal_topk = tf.minimum(tf.size(proposal_scores),
                                   config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK)
        proposal_scores, topk_indices = tf.nn.top_k(proposal_scores, k=proposal_topk, sorted=False)
        proposal_boxes = tf.gather(proposal_boxes, topk_indices)

        if is_training:
            rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
                proposal_boxes, gt_boxes, gt_labels)
        else:
            # The boxes to be used to crop RoIs.
            rcnn_boxes = proposal_boxes

        # RoI features are pooled from the first four pyramid levels only.
        roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)

        fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_2fc_head(
            'fastrcnn', roi_feature_fastrcnn, config.NUM_CLASS)

        if is_training:
            # rpn loss is already defined above
            with tf.name_scope('rpn_losses'):
                # Even entries are label losses, odd entries are box losses;
                # these totals are only used for the moving summaries.
                rpn_total_label_loss = tf.add_n(rpn_loss_collection[::2], name='label_loss')
                rpn_total_box_loss = tf.add_n(rpn_loss_collection[1::2], name='box_loss')
                add_moving_summary(rpn_total_box_loss, rpn_total_label_loss)

            # fastrcnn loss:
            matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)

            # Samples with label > 0 are treated as foreground; only those
            # contribute boxes and box logits to the regression loss.
            fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])   # fg inds w.r.t all samples
            fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
            fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)

            fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
                image, rcnn_labels, fg_sampled_boxes,
                matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)

            if config.MODE_MASK:
                # maskrcnn loss
                fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
                roi_feature_maskrcnn = multilevel_roi_align(
                    p23456[:4], fg_sampled_boxes, 14)
                mask_logits = maskrcnn_upXconv_head(
                    'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4)   # #fg x #cat x 28 x 28

                # Crop directly from the full gt mask stack: fg_inds_wrt_gt
                # selects which gt mask each sampled box is cropped from.
                target_masks_for_fg = crop_and_resize(
                    tf.expand_dims(gt_masks, 1),
                    fg_sampled_boxes,
                    fg_inds_wrt_gt, 28,
                    pad_border=False)  # fg x 1x28x28
                target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
                mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
            else:
                # Constant placeholder so the sum below keeps the same shape.
                mrcnn_loss = 0.0

            # L2 regularization (1e-4) over the parameters selected by the
            # name pattern below.
            wd_cost = regularize_cost(
                '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
                l2_regularizer(1e-4), name='wd_cost')

            total_cost = tf.add_n(rpn_loss_collection + [
                fastrcnn_label_loss, fastrcnn_box_loss,
                mrcnn_loss, wd_cost], 'total_cost')

            add_moving_summary(total_cost, wd_cost)
            return total_cost
        else:
            final_boxes, final_labels = self.fastrcnn_inference(
                image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
            if config.MODE_MASK:
                # Cascade inference needs roi transform with refined boxes.
                roi_feature_maskrcnn = multilevel_roi_align(
                    p23456[:4], final_boxes, 14)
                mask_logits = maskrcnn_upXconv_head(
                    'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4)   # #result x #cat x 28 x 28
                # Pick, for each detection, the mask of its predicted
                # category (labels are 1-based, mask channels 0-based).
                indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
                final_mask_logits = tf.gather_nd(mask_logits, indices)   # #resultx28x28
                tf.sigmoid(final_mask_logits, name='final_masks')
Exemple #11
0
    def build_graph(self, *inputs):
        is_training = get_current_tower_context().is_training
        if config.MODE_MASK:
            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
        else:
            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
        image = self.preprocess(image)     # 1CHW

        featuremap = resnet_c4_backbone(image, config.RESNET_NUM_BLOCK[:3])
        rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)

        fm_anchors, anchor_labels, anchor_boxes = self.narrow_to_featuremap(
            featuremap, get_all_anchors(), anchor_labels, anchor_boxes)
        anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)

        image_shape2d = tf.shape(image)[2:]     # h,w
        pred_boxes_decoded = decode_bbox_target(rpn_box_logits, fm_anchors)  # fHxfWxNAx4, floatbox
        proposal_boxes, proposal_scores = generate_rpn_proposals(
            tf.reshape(pred_boxes_decoded, [-1, 4]),
            tf.reshape(rpn_label_logits, [-1]),
            image_shape2d,
            config.TRAIN_PRE_NMS_TOPK if is_training else config.TEST_PRE_NMS_TOPK,
            config.TRAIN_POST_NMS_TOPK if is_training else config.TEST_POST_NMS_TOPK)

        if is_training:
            # sample proposal boxes in training
            rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
                proposal_boxes, gt_boxes, gt_labels)
        else:
            # The boxes to be used to crop RoIs.
            # Use all proposal boxes in inference
            rcnn_boxes = proposal_boxes

        boxes_on_featuremap = rcnn_boxes * (1.0 / config.ANCHOR_STRIDE)
        roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)

        # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
        # which was fixed in TF 1.6
        def ff_true():
            feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])    # nxcx7x7
            feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first')
            fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, config.NUM_CLASS)
            # Return C5 feature to be shared with mask branch
            return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits

        def ff_false():
            ncls = config.NUM_CLASS
            return tf.zeros([0, 2048, 7, 7]), tf.zeros([0, ncls]), tf.zeros([0, ncls - 1, 4])

        if get_tf_version_number() >= 1.6:
            feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = ff_true()
        else:
            logger.warn("This example may drop support for TF < 1.6 soon.")
            feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
                tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)

        if is_training:
            # rpn loss
            rpn_label_loss, rpn_box_loss = rpn_losses(
                anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)

            # fastrcnn loss
            matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)

            fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])   # fg inds w.r.t all samples
            fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
            fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)

            fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
                image, rcnn_labels, fg_sampled_boxes,
                matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)

            if config.MODE_MASK:
                # maskrcnn loss
                fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
                # In training, mask branch shares the same C5 feature.
                fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
                mask_logits = maskrcnn_upXconv_head(
                    'maskrcnn', fg_feature, config.NUM_CLASS, num_convs=0)   # #fg x #cat x 14x14

                target_masks_for_fg = crop_and_resize(
                    tf.expand_dims(gt_masks, 1),
                    fg_sampled_boxes,
                    fg_inds_wrt_gt, 14,
                    pad_border=False)  # nfg x 1x14x14
                target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
                mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
            else:
                mrcnn_loss = 0.0

            wd_cost = regularize_cost(
                '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
                l2_regularizer(1e-4), name='wd_cost')

            total_cost = tf.add_n([
                rpn_label_loss, rpn_box_loss,
                fastrcnn_label_loss, fastrcnn_box_loss,
                mrcnn_loss,
                wd_cost], 'total_cost')

            add_moving_summary(total_cost, wd_cost)
            return total_cost
        else:
            final_boxes, final_labels = self.fastrcnn_inference(
                image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)

            if config.MODE_MASK:
                # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
                def f1():
                    roi_resized = roi_align(featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
                    feature_maskrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])
                    mask_logits = maskrcnn_upXconv_head(
                        'maskrcnn', feature_maskrcnn, config.NUM_CLASS, 0)   # #result x #cat x 14x14
                    indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
                    final_mask_logits = tf.gather_nd(mask_logits, indices)   # #resultx14x14
                    return tf.sigmoid(final_mask_logits)

                final_masks = tf.cond(tf.size(final_labels) > 0, f1, lambda: tf.zeros([0, 14, 14]))
                tf.identity(final_masks, name='final_masks')
Exemple #12
0
    def _build_graph(self, inputs):
        """
        Build the RPN + Fast R-CNN graph for the current tower.

        In training mode, proposals are sampled against the ground truth, the
        Fast R-CNN head is run on the sampled boxes, and ``self.cost`` is set
        to the sum of the RPN losses, the Fast R-CNN losses and the
        weight-decay cost. In inference mode, the head is run on all proposals
        and the results are exposed through named tensors
        ('fastrcnn_all_probs', 'fastrcnn_fg_probs', 'fastrcnn_fg_boxes').

        Args:
            inputs: a sequence that unpacks into
                (image, anchor_labels, anchor_boxes, gt_boxes, gt_labels).
        """
        is_training = get_current_tower_context().is_training
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
        # Anchors are computed from the image as given, i.e. before
        # self._preprocess() is applied to it.
        fm_anchors = self._get_anchors(image)
        image = self._preprocess(image)

        # Encode the ground-truth anchor boxes relative to the anchors so they
        # can be compared against the RPN box outputs in rpn_losses().
        anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
        # Backbone uses the first three entries of RESNET_NUM_BLOCK; the last
        # entry is used by the Fast R-CNN head below.
        featuremap = pretrained_resnet_conv4(image,
                                             config.RESNET_NUM_BLOCK[:3])
        rpn_label_logits, rpn_box_logits = rpn_head(featuremap, 1024,
                                                    config.NUM_ANCHOR)
        # NOTE: the RPN losses are built in both modes, although within this
        # method they are only consumed by the training branch.
        rpn_label_loss, rpn_box_loss = rpn_losses(anchor_labels,
                                                  anchor_boxes_encoded,
                                                  rpn_label_logits,
                                                  rpn_box_logits)

        decoded_boxes = decode_bbox_target(
            rpn_box_logits, fm_anchors,
            config.ANCHOR_STRIDE)  # (fHxfWxNA)x4, floatbox
        # The label logits are flattened to 1-D to act as per-box scores.
        # tf.shape(image)[2:] passes the trailing dimensions of the
        # preprocessed image -- presumably (H, W) of an NCHW tensor; TODO confirm.
        proposal_boxes, proposal_scores = generate_rpn_proposals(
            decoded_boxes, tf.reshape(rpn_label_logits, [-1]),
            tf.shape(image)[2:])

        if is_training:
            # Sample a subset of proposals together with their encoded box
            # targets and labels.
            rcnn_sampled_boxes, rcnn_encoded_boxes, rcnn_labels = sample_fast_rcnn_targets(
                proposal_boxes, gt_boxes, gt_labels)
            # Rescale boxes by 1 / ANCHOR_STRIDE before cropping from the
            # feature map.
            boxes_on_featuremap = rcnn_sampled_boxes * (1.0 /
                                                        config.ANCHOR_STRIDE)
            # 14 is presumably the RoI crop resolution -- TODO confirm against roi_align.
            roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
            feature_fastrcnn = resnet_conv5_gap(
                roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxc
            fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(
                feature_fastrcnn, config.NUM_CLASS)

            fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
                rcnn_labels, rcnn_encoded_boxes, fastrcnn_label_logits,
                fastrcnn_box_logits)

            # Weight decay: l2_regularizer(1e-4) over the variables selected
            # by this name regex (group1-3, rpn and fastrcnn scopes).
            wd_cost = regularize_cost(
                '(?:group1|group2|group3|rpn|fastrcnn)/.*W',
                l2_regularizer(1e-4),
                name='wd_cost')

            # Total training objective, exposed as the 'total_cost' tensor.
            self.cost = tf.add_n([
                rpn_label_loss, rpn_box_loss, fastrcnn_label_loss,
                fastrcnn_box_loss, wd_cost
            ], 'total_cost')

            # Track both the total cost and the weight-decay term as moving
            # summaries.
            for k in self.cost, wd_cost:
                add_moving_summary(k)
        else:
            # Inference: no sampling, the head is run on every proposal box.
            roi_resized = roi_align(
                featuremap, proposal_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
            feature_fastrcnn = resnet_conv5_gap(
                roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxc
            label_logits, fastrcnn_box_logits = fastrcnn_head(
                feature_fastrcnn, config.NUM_CLASS)
            # Softmax over the label logits, exposed as 'fastrcnn_all_probs'.
            label_probs = tf.nn.softmax(label_logits,
                                        name='fastrcnn_all_probs')  # NP,
            # Predicted label for each row of label_logits.
            labels = tf.argmax(label_logits, axis=1)
            # fg_ind indexes the proposals that are kept; presumably those
            # whose predicted label is a foreground class -- TODO confirm
            # against fastrcnn_predict_boxes.
            fg_ind, fg_box_logits = fastrcnn_predict_boxes(
                labels, fastrcnn_box_logits)
            fg_label_probs = tf.gather(label_probs,
                                       fg_ind,
                                       name='fastrcnn_fg_probs')
            fg_boxes = tf.gather(proposal_boxes, fg_ind)

            # Divide the box logits by the regression weights before decoding;
            # presumably this undoes a scaling applied to the training
            # targets -- TODO confirm.
            fg_box_logits = fg_box_logits / tf.constant(
                config.FASTRCNN_BBOX_REG_WEIGHTS)
            decoded_boxes = decode_bbox_target(
                fg_box_logits, fg_boxes,
                config.ANCHOR_STRIDE)  # Nfx4, floatbox
            # Give the final boxes a stable name so they can be fetched from
            # the graph.
            decoded_boxes = tf.identity(decoded_boxes,
                                        name='fastrcnn_fg_boxes')