Пример #1
0
    def init_modules(self):
        """Create the backbone, RPN, RoI pooling op, prediction heads and losses.

        Everything is driven by attributes already stored on ``self``
        (``feature_extractor_config``, ``rpn_config``, ``pooling_mode``,
        ``pooling_size``, ``n_classes``, ``reduce``, ``class_agnostic`` and
        ``use_focal_loss``). The created objects are stored as attributes.

        Raises:
            NotImplementedError: if ``pooling_mode`` is ``'psalign'`` or
                ``'deformable_psalign'``.
        """
        self.feature_extractor = ResNetFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)

        # RoI pooling operator. A mode that matches none of the names below
        # leaves ``rcnn_pooling`` unset.
        mode = self.pooling_mode
        if mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(
                self.pooling_size, self.pooling_size, 1.0 / 16.0)
        elif mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)
        elif mode in ('psalign', 'deformable_psalign'):
            raise NotImplementedError('have not implemented yet!')

        # Prediction heads: the classifier is created first, then the
        # regressor, whose input width depends on ``reduce`` and whose output
        # width depends on ``class_agnostic``.
        self.rcnn_cls_pred = nn.Linear(2048, self.n_classes)
        in_channels = 2048 if self.reduce else 2048 * 4 * 4
        num_box_outputs = 4 if self.class_agnostic else 4 * self.n_classes
        self.rcnn_bbox_pred = nn.Linear(in_channels, num_box_outputs)

        # Loss modules.
        if self.use_focal_loss:
            cls_loss = FocalLoss(2)
        else:
            cls_loss = functools.partial(F.cross_entropy, reduce=False)
        self.rcnn_cls_loss = cls_loss
        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)
Пример #2
0
    def init_modules(self):
        """Create the backbone, RPN, RoI pooling ops, heads, losses and extras.

        The feature extractor comes from ``feature_extractors_builder.build``.
        Both prediction heads take ``self.ndin`` input features. In addition
        to the usual heads this variant creates an optional spatial-attention
        convolution (when ``use_self_attention`` is set), a second RoI-align
        op with scale 1/8 and a 1x1 convolution + ReLU stored as
        ``reduce_pooling``.

        Raises:
            NotImplementedError: if ``pooling_mode`` is ``'psalign'`` or
                ``'deformable_psalign'``.
        """
        self.feature_extractor = feature_extractors_builder.build(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)

        # RoI pooling operator. A mode that matches none of the names below
        # leaves ``rcnn_pooling`` unset.
        mode = self.pooling_mode
        if mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(
                self.pooling_size, self.pooling_size, 1.0 / 16.0)
        elif mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)
        elif mode in ('psalign', 'deformable_psalign'):
            raise NotImplementedError('have not implemented yet!')

        # Prediction heads (classifier first, then box regressor).
        self.rcnn_cls_pred = nn.Linear(self.ndin, self.n_classes)
        if self.class_agnostic:
            self.rcnn_bbox_pred = nn.Linear(self.ndin, 4)
        else:
            self.rcnn_bbox_pred = nn.Linear(self.ndin, 4 * self.n_classes)

        # Loss modules.
        if self.use_focal_loss:
            cls_loss = FocalLoss(2, gamma=2, alpha=0.2, auto_alpha=False)
        else:
            cls_loss = functools.partial(F.cross_entropy, reduce=False)
        self.rcnn_cls_loss = cls_loss
        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

        # Optional single-channel 3x3 attention map over the features.
        if self.use_self_attention:
            self.spatial_attention = nn.Conv2d(
                self.ndin, 1, kernel_size=3, stride=1, padding=1)

        # Second RoI-align op (scale 1/8) and the 1x1 conv that maps
        # 512 channels to 1024.
        self.rcnn_pooling2 = RoIAlignAvg(
            self.pooling_size, self.pooling_size, 1.0 / 8.0)
        channel_lift = nn.Conv2d(512, 1024, kernel_size=1, stride=1, padding=0)
        self.reduce_pooling = nn.Sequential(channel_lift, nn.ReLU())
Пример #3
0
    def init_modules(self):
        """Create the backbone, RPN, RoI pooling ops, heads and losses.

        In this variant the classification head always has two outputs and
        its loss is an unreduced ``nn.MSELoss``. When ``multiple_crop`` is
        set, a second RoI-align op (scale 1/8) and a 1x1 fusion convolution
        are created as well.

        Raises:
            NotImplementedError: if ``pooling_mode`` is ``'psalign'`` or
                ``'deformable_psalign'``.
        """
        self.feature_extractor = ResNetFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)

        # RoI pooling operator. A mode that matches none of the names below
        # leaves ``rcnn_pooling`` unset.
        mode = self.pooling_mode
        if mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(
                self.pooling_size, self.pooling_size, 1.0 / 16.0)
        elif mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)
        elif mode in ('psalign', 'deformable_psalign'):
            raise NotImplementedError('have not implemented yet!')

        # Prediction heads (classifier first, then box regressor).
        self.rcnn_cls_pred = nn.Linear(2048, 2)
        in_channels = 2048 if self.reduce else 2048 * 4 * 4
        num_box_outputs = 4 if self.class_agnostic else 4 * self.n_classes
        self.rcnn_bbox_pred = nn.Linear(in_channels, num_box_outputs)

        if self.multiple_crop:
            self.rcnn_pooling2 = RoIAlignAvg(
                self.pooling_size, self.pooling_size, 1.0 / 8.0)
            # 1x1 fusion conv from 512 to 1024 channels
            self.pooled_feat_fusion = nn.Conv2d(
                512, 1024, kernel_size=1, stride=1, padding=0)

        # Loss modules: the score head is trained with an element-wise MSE
        # instead of a classification loss.
        self.rcnn_cls_loss = nn.MSELoss(reduce=False)
        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)
Пример #4
0
class DoubleIoUFasterRCNN(Model):
    """Two-stage detector whose second stage only regresses boxes.

    A classification head (``rcnn_cls_pred``) is created and initialised,
    but ``forward`` never calls it: the ``rcnn_cls_probs`` entry of the
    prediction dict is filled from the RPN probabilities instead.
    """

    def forward(self, feed_dict):
        """Run the backbone, the RPN and the box-regression head.

        Args:
            feed_dict: dict holding at least ``'img'``; during training
                ``pre_subsample`` also reads ``'gt_boxes'`` and
                ``'gt_labels'``. The key ``'base_feat'`` is added in place.

        Returns:
            dict with the RPN outputs plus ``'rcnn_bbox_preds'``,
            ``'second_rpn_anchors'`` and ``'rcnn_cls_probs'``.
        """

        prediction_dict = {}

        # base model
        base_feat = self.feature_extractor.first_stage_feature(
            feed_dict['img'])
        # the RPN reads the shared feature map from feed_dict
        feed_dict.update({'base_feat': base_feat})
        self.add_feat('base_feat', base_feat)
        # batch_size = base_feat.shape[0]

        # rpn model
        prediction_dict.update(self.rpn_model.forward(feed_dict))

        # proposals = prediction_dict['proposals_batch']
        # shape(N,num_proposals,5)
        # subsample the proposals first to reduce memory consumption;
        # this replaces prediction_dict['rois_batch'] in place
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict)
        rois_batch = prediction_dict['rois_batch']

        # note here base_feat (N,C,H,W),rois_batch (N,num_proposals,5)
        pooled_feat = self.rcnn_pooling(base_feat, rois_batch.view(-1, 5))

        # shape(N,C,1,1)
        pooled_feat_reg = self.feature_extractor.second_stage_feature(
            pooled_feat)
        # shape(N,C)
        # if self.reduce:
        # average over the two spatial dimensions
        pooled_feat_reg = pooled_feat_reg.mean(3).mean(2)
        # else:
        # pooled_feat = pooled_feat.view(self.rcnn_batch_size, -1)

        rcnn_bbox_preds = self.rcnn_bbox_pred(pooled_feat_reg)
        # rcnn_cls_scores = self.rcnn_cls_pred(pooled_feat)

        # rcnn_cls_probs = F.softmax(rcnn_cls_scores, dim=1)

        # prediction_dict['rcnn_cls_probs'] = rcnn_cls_probs
        prediction_dict['rcnn_bbox_preds'] = rcnn_bbox_preds
        # prediction_dict['rcnn_cls_scores'] = rcnn_cls_scores

        # used for tracking: pick the anchors and RPN probabilities that
        # correspond to the kept proposals
        proposals_order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = prediction_dict['anchors'][
            proposals_order]
        # the RPN probabilities stand in for second-stage class scores
        prediction_dict['rcnn_cls_probs'] = prediction_dict['rpn_cls_probs'][
            0][proposals_order]

        return prediction_dict

    # def rcnn_cls_pred(pooled_feat)

    def init_weights(self):
        """Initialise the submodules and both linear heads.

        The two heads are filled via ``Filler.normal_init`` with mean 0 and
        std 0.01 (classifier) / 0.001 (box regressor).
        """
        # submodule init weights
        self.feature_extractor.init_weights()
        self.rpn_model.init_weights()

        Filler.normal_init(self.rcnn_cls_pred, 0, 0.01, self.truncated)
        Filler.normal_init(self.rcnn_bbox_pred, 0, 0.001, self.truncated)

    def init_modules(self):
        """Create the backbone, RPN, RoI pooling op, prediction heads and losses.

        Raises:
            NotImplementedError: if ``pooling_mode`` is ``'psalign'`` or
                ``'deformable_psalign'``.
        """
        self.feature_extractor = ResNetFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)
        # a pooling_mode matching none of the branches leaves rcnn_pooling
        # unset
        if self.pooling_mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(self.pooling_size,
                                            self.pooling_size, 1.0 / 16.0)
        elif self.pooling_mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)
        elif self.pooling_mode == 'psalign':
            raise NotImplementedError('have not implemented yet!')
        elif self.pooling_mode == 'deformable_psalign':
            raise NotImplementedError('have not implemented yet!')
        self.rcnn_cls_pred = nn.Linear(2048, self.n_classes)
        # init_param hard-codes reduce=True, so in_channels is 2048 unless
        # the attribute is changed afterwards
        if self.reduce:
            in_channels = 2048
        else:
            in_channels = 2048 * 4 * 4
        if self.class_agnostic:
            self.rcnn_bbox_pred = nn.Linear(in_channels, 4)
        else:
            self.rcnn_bbox_pred = nn.Linear(in_channels, 4 * self.n_classes)

        # loss module
        if self.use_focal_loss:
            self.rcnn_cls_loss = FocalLoss(2)
        else:
            self.rcnn_cls_loss = functools.partial(F.cross_entropy,
                                                   reduce=False)

        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

    def init_param(self, model_config):
        """Copy settings from ``model_config`` and build assigner and sampler.

        Args:
            model_config: mapping that must contain the keys ``'classes'``,
                ``'class_agnostic'``, ``'pooling_size'``, ``'pooling_mode'``,
                ``'crop_resize_with_max_pool'``, ``'truncated'``,
                ``'use_focal_loss'``, ``'subsample_twice'``,
                ``'rcnn_batch_size'``, ``'feature_extractor_config'``,
                ``'rpn_config'``, ``'target_assigner_config'`` and
                ``'sampler_config'``.
        """
        classes = model_config['classes']
        self.classes = classes
        self.n_classes = len(classes)
        self.class_agnostic = model_config['class_agnostic']
        self.pooling_size = model_config['pooling_size']
        self.pooling_mode = model_config['pooling_mode']
        self.crop_resize_with_max_pool = model_config[
            'crop_resize_with_max_pool']
        self.truncated = model_config['truncated']

        self.use_focal_loss = model_config['use_focal_loss']
        self.subsample_twice = model_config['subsample_twice']
        self.rcnn_batch_size = model_config['rcnn_batch_size']

        # some submodule config
        self.feature_extractor_config = model_config[
            'feature_extractor_config']
        self.rpn_config = model_config['rpn_config']

        # assigner
        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])

        # sampler
        self.sampler = BalancedSampler(model_config['sampler_config'])

        # the config value is ignored; reduce is always on
        # self.reduce = model_config.get('reduce')
        self.reduce = True

    def pre_subsample(self, prediction_dict, feed_dict):
        """Assign regression targets to the RoIs and keep a sampled subset.

        Writes ``'rcnn_reg_weights'`` (divided by the number of positive
        weights), ``'rcnn_reg_targets'`` and the subsampled ``'rois_batch'``
        back into ``prediction_dict``.

        Args:
            prediction_dict: dict holding ``'rois_batch'``; updated in place.
            feed_dict: dict holding ``'gt_boxes'`` and ``'gt_labels'``.

        Raises:
            AssertionError: if no sampled RoI has a positive regression
                weight.
        """
        rois_batch = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        ##########################
        # assigner
        ##########################
        #  import ipdb
        #  ipdb.set_trace()
        # column 0 of each RoI is dropped before the assignment
        rcnn_cls_targets, rcnn_reg_targets, rcnn_cls_weights, rcnn_reg_weights = self.target_assigner.assign(
            rois_batch[:, :, 1:], gt_boxes, gt_labels)

        ##########################
        # subsampler
        ##########################
        cls_criterion = None
        # positives are the RoIs that carry a regression weight
        pos_indicator = rcnn_reg_weights > 0
        # indicator = rcnn_cls_weights > 0
        indicator = None

        # subsample from all
        # shape (N,M)
        batch_sampled_mask = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            pos_indicator,
            indicator=indicator,
            criterion=cls_criterion)
        rcnn_cls_weights = rcnn_cls_weights[batch_sampled_mask]
        rcnn_reg_weights = rcnn_reg_weights[batch_sampled_mask]
        # num_cls_coeff = (rcnn_cls_weights > 0).sum(dim=-1)
        # number of positive regression weights, used as normaliser below
        num_reg_coeff = (rcnn_reg_weights > 0).sum(dim=-1)
        # check
        # assert num_cls_coeff, 'bug happens'
        assert num_reg_coeff, 'bug happens'

        # prediction_dict[
        # 'rcnn_cls_weights'] = rcnn_cls_weights / num_cls_coeff.float()
        prediction_dict[
            'rcnn_reg_weights'] = rcnn_reg_weights / num_reg_coeff.float()
        # prediction_dict['rcnn_cls_targets'] = rcnn_cls_targets[
        # batch_sampled_mask]
        prediction_dict['rcnn_reg_targets'] = rcnn_reg_targets[
            batch_sampled_mask]

        # update rois_batch
        prediction_dict['rois_batch'] = rois_batch[batch_sampled_mask].view(
            rois_batch.shape[0], -1, 5)

        # NOTE(review): forward() only calls this method while
        # self.training is true, so from that call path this branch never
        # runs -- confirm whether the condition is intended.
        if not self.training:
            # used for track
            proposals_order = prediction_dict['proposals_order']

            prediction_dict['proposals_order'] = proposals_order[
                batch_sampled_mask]

    def loss(self, prediction_dict, feed_dict):
        """Compute the RPN losses and the weighted box-regression loss.

        Targets and weights are read from ``prediction_dict``, where
        ``pre_subsample`` stored them; no assignment happens here.

        Args:
            prediction_dict: output of ``forward`` in training mode.
            feed_dict: passed through to ``rpn_model.loss``.

        Returns:
            dict with the RPN loss entries plus ``'rcnn_bbox_loss'``.
        """
        loss_dict = {}

        # submodule loss
        loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

        # targets and weights
        # rcnn_cls_weights = prediction_dict['rcnn_cls_weights']
        rcnn_reg_weights = prediction_dict['rcnn_reg_weights']

        # rcnn_cls_targets = prediction_dict['rcnn_cls_targets']
        rcnn_reg_targets = prediction_dict['rcnn_reg_targets']

        # classification loss
        # rcnn_cls_scores = prediction_dict['rcnn_cls_scores']
        # rcnn_cls_loss = self.rcnn_cls_loss(rcnn_cls_scores, rcnn_cls_targets)
        # rcnn_cls_loss *= rcnn_cls_weights
        # rcnn_cls_loss = rcnn_cls_loss.sum(dim=-1)

        # bounding box regression L1 loss: sum over the last dim of the
        # unreduced loss, weight each RoI, then sum again
        rcnn_bbox_preds = prediction_dict['rcnn_bbox_preds']
        rcnn_bbox_loss = self.rcnn_bbox_loss(rcnn_bbox_preds,
                                             rcnn_reg_targets).sum(dim=-1)
        rcnn_bbox_loss *= rcnn_reg_weights
        # rcnn_bbox_loss *= rcnn_reg_weights
        rcnn_bbox_loss = rcnn_bbox_loss.sum(dim=-1)

        # loss weights has no gradients
        # loss_dict['rcnn_cls_loss'] = rcnn_cls_loss
        loss_dict['rcnn_bbox_loss'] = rcnn_bbox_loss

        # add rcnn_cls_targets to get the statics of rpn
        # loss_dict['rcnn_cls_targets'] = rcnn_cls_targets

        return loss_dict
class DetachDoubleIOUFasterRCNN(Model):
    """Faster R-CNN variant with separately switchable box and class heads.

    The box head runs on the second-stage features. The optional class head
    runs on third-stage features computed from a *detached* copy of the
    pooled RoI features, so its loss does not back-propagate into the shared
    feature map. The switches ``enable_cls``, ``enable_reg``, ``enable_iou``
    and ``enable_track_rois`` are set to fixed defaults in ``init_param``.
    """

    def forward(self, feed_dict):
        """Run the backbone, the RPN and the enabled second-stage heads.

        Args:
            feed_dict: dict holding at least ``'img'``; ``'gt_boxes'`` and
                ``'gt_labels'`` are read during training and when
                ``enable_track_rois`` is set in eval mode. The key
                ``'base_feat'`` is added in place.

        Returns:
            dict with the RPN outputs plus ``'rcnn_bbox_preds'``,
            ``'second_rpn_anchors'``, ``'second_rpn_cls_probs'`` and, when
            ``enable_cls`` is set, ``'rcnn_cls_probs'`` and
            ``'rcnn_cls_scores'``.

        Raises:
            ValueError: in eval mode when ``enable_iou`` is set without
                ``enable_cls``; the IoU rescoring needs the class
                probabilities.
        """

        prediction_dict = {}

        # base model
        base_feat = self.feature_extractor.first_stage_feature(
            feed_dict['img'])
        # the RPN reads the shared feature map from feed_dict
        feed_dict.update({'base_feat': base_feat})
        self.add_feat('base_feat', base_feat)

        # rpn model
        prediction_dict.update(self.rpn_model.forward(feed_dict))

        # subsample the proposals first to reduce memory consumption;
        # this replaces prediction_dict['rois_batch'] in place
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict)
        rois_batch = prediction_dict['rois_batch']

        # note here base_feat (N,C,H,W),rois_batch (N,num_proposals,5)
        pooled_feat = self.rcnn_pooling(base_feat, rois_batch.view(-1, 5))

        # The regression branch always runs, regardless of enable_reg.
        # shape(N,C,1,1)
        pooled_feat_reg = self.feature_extractor.second_stage_feature(
            pooled_feat)

        # average over the two spatial dimensions
        pooled_feat_reg = pooled_feat_reg.mean(3).mean(2)
        rcnn_bbox_preds = self.rcnn_bbox_pred(pooled_feat_reg)
        prediction_dict['rcnn_bbox_preds'] = rcnn_bbox_preds

        if self.enable_cls:
            # detach so the class branch cannot update the shared features
            pooled_feat_cls = self.feature_extractor.third_stage_feature(
                pooled_feat.detach())

            # shape(N,C)
            pooled_feat_cls = pooled_feat_cls.mean(3).mean(2)
            rcnn_cls_scores = self.rcnn_cls_pred(pooled_feat_cls)
            rcnn_cls_probs = F.softmax(rcnn_cls_scores, dim=1)

            prediction_dict['rcnn_cls_probs'] = rcnn_cls_probs
            prediction_dict['rcnn_cls_scores'] = rcnn_cls_scores

        # used for tracking: pick the anchors and RPN probabilities that
        # correspond to the kept proposals
        proposals_order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = prediction_dict['anchors'][
            proposals_order]
        prediction_dict['second_rpn_cls_probs'] = prediction_dict[
            'rpn_cls_probs'][0][proposals_order]

        if not self.training and self.enable_iou:
            # rcnn_cls_probs only exists when the class branch above ran;
            # fail with a clear message instead of an UnboundLocalError.
            if not self.enable_cls:
                raise ValueError(
                    'enable_iou requires enable_cls: the IoU rescoring '
                    'filters background using rcnn_cls_probs')
            # use the IoU between each decoded box and its RoI as the final
            # score; RoIs with foreground probability < 0.5 get score 0
            pred_boxes = self.bbox_coder.decode_batch(
                rcnn_bbox_preds.view(1, -1, 4), rois_batch[:, :, 1:5])
            iou_matrix = box_ops.iou(pred_boxes, rois_batch[:, :, 1:5])[0]
            iou_matrix[rcnn_cls_probs[:, 1] < 0.5] = 0
            rcnn_cls_probs[:, 1] = iou_matrix
            prediction_dict['rcnn_cls_probs'] = rcnn_cls_probs

        if not self.training and self.enable_track_rois:
            # the return value is discarded; presumably called for the
            # statistics the assigner records internally -- TODO confirm
            self.target_assigner.assign(rois_batch[:, :,
                                                   1:], feed_dict['gt_boxes'],
                                        feed_dict['gt_labels'])

        return prediction_dict

    def init_weights(self):
        """Initialise the submodules and both linear heads.

        The two heads are filled via ``Filler.normal_init`` with mean 0 and
        std 0.01 (classifier) / 0.001 (box regressor).
        """
        # submodule init weights
        self.feature_extractor.init_weights()
        self.rpn_model.init_weights()

        Filler.normal_init(self.rcnn_cls_pred, 0, 0.01, self.truncated)
        Filler.normal_init(self.rcnn_bbox_pred, 0, 0.001, self.truncated)

    def init_modules(self):
        """Create the backbone, RPN, RoI pooling op, prediction heads and losses.

        Raises:
            NotImplementedError: if ``pooling_mode`` is ``'psalign'`` or
                ``'deformable_psalign'``.
        """
        self.feature_extractor = ResNetFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)
        # a pooling_mode matching none of the branches leaves rcnn_pooling
        # unset
        if self.pooling_mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(self.pooling_size,
                                            self.pooling_size, 1.0 / 16.0)
        elif self.pooling_mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)
        elif self.pooling_mode == 'psalign':
            raise NotImplementedError('have not implemented yet!')
        elif self.pooling_mode == 'deformable_psalign':
            raise NotImplementedError('have not implemented yet!')
        self.rcnn_cls_pred = nn.Linear(2048, self.n_classes)
        # init_param hard-codes reduce=True, so in_channels is 2048 unless
        # the attribute is changed afterwards
        if self.reduce:
            in_channels = 2048
        else:
            in_channels = 2048 * 4 * 4
        if self.class_agnostic:
            self.rcnn_bbox_pred = nn.Linear(in_channels, 4)
        else:
            self.rcnn_bbox_pred = nn.Linear(in_channels, 4 * self.n_classes)

        # loss module
        if self.use_focal_loss:
            self.rcnn_cls_loss = FocalLoss(2, alpha=0.25)
        else:
            self.rcnn_cls_loss = functools.partial(F.cross_entropy,
                                                   reduce=False)

        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

    def init_param(self, model_config):
        """Copy settings from ``model_config`` and set the feature switches.

        Args:
            model_config: mapping that must contain the keys ``'classes'``,
                ``'class_agnostic'``, ``'pooling_size'``, ``'pooling_mode'``,
                ``'crop_resize_with_max_pool'``, ``'truncated'``,
                ``'use_focal_loss'``, ``'subsample_twice'``,
                ``'rcnn_batch_size'``, ``'feature_extractor_config'``,
                ``'rpn_config'``, ``'target_assigner_config'`` and
                ``'sampler_config'``.
        """
        classes = model_config['classes']
        self.classes = classes
        self.n_classes = len(classes)
        self.class_agnostic = model_config['class_agnostic']
        self.pooling_size = model_config['pooling_size']
        self.pooling_mode = model_config['pooling_mode']
        self.crop_resize_with_max_pool = model_config[
            'crop_resize_with_max_pool']
        self.truncated = model_config['truncated']

        self.use_focal_loss = model_config['use_focal_loss']
        self.subsample_twice = model_config['subsample_twice']
        self.rcnn_batch_size = model_config['rcnn_batch_size']

        # some submodule config
        self.feature_extractor_config = model_config[
            'feature_extractor_config']
        self.rpn_config = model_config['rpn_config']

        # assigner
        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])

        # bbox_coder (shared with the assigner)
        self.bbox_coder = self.target_assigner.bbox_coder

        # similarity (shared with the assigner)
        self.similarity_calc = self.target_assigner.similarity_calc

        # sampler
        self.sampler = BalancedSampler(model_config['sampler_config'])

        # the config value is ignored; reduce is always on
        # self.reduce = model_config.get('reduce')
        self.reduce = True

        # optimize cls
        self.enable_cls = False

        # optimize reg
        self.enable_reg = True

        # cal iou (eval only; requires enable_cls, see forward)
        self.enable_iou = False

        # track good rois
        self.enable_track_rois = False

    def pre_subsample(self, prediction_dict, feed_dict):
        """Assign targets to the RoIs and keep a sampled subset.

        Positives are chosen from the regression weights when ``enable_reg``
        is set, otherwise from the class targets when ``enable_cls`` is set.
        The sampled targets, normalised weights, ``'fake_match'`` and the
        subsampled ``'rois_batch'`` are written back into
        ``prediction_dict``.

        Args:
            prediction_dict: dict holding ``'rois_batch'``; updated in place.
            feed_dict: dict holding ``'gt_boxes'`` and ``'gt_labels'``.

        Raises:
            ValueError: if neither ``enable_reg`` nor ``enable_cls`` is set.
            AssertionError: if the sampled subset contains no positive
                weight.
        """
        rois_batch = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        ##########################
        # assigner
        ##########################
        # column 0 of each RoI is dropped before the assignment
        rcnn_cls_targets, rcnn_reg_targets, rcnn_cls_weights, rcnn_reg_weights = self.target_assigner.assign(
            rois_batch[:, :, 1:], gt_boxes, gt_labels)

        ##########################
        # subsampler
        ##########################
        cls_criterion = None

        if self.enable_reg:
            # used for reg training
            pos_indicator = rcnn_reg_weights > 0
            indicator = None
        elif self.enable_cls:
            # used for cls training
            pos_indicator = rcnn_cls_targets > 0
            indicator = rcnn_cls_weights > 0
        else:
            raise ValueError("please check enable reg and enable cls again")

        # subsample from all
        # shape (N,M)
        batch_sampled_mask = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            pos_indicator,
            indicator=indicator,
            criterion=cls_criterion)

        if self.enable_cls:
            rcnn_cls_weights = rcnn_cls_weights[batch_sampled_mask]
            # number of positive class weights, used as normaliser
            num_cls_coeff = (rcnn_cls_weights > 0).sum(dim=-1)
            assert num_cls_coeff, 'bug happens'
            prediction_dict[
                'rcnn_cls_weights'] = rcnn_cls_weights / num_cls_coeff.float()

        # used for retrieving statistics
        prediction_dict['rcnn_cls_targets'] = rcnn_cls_targets[
            batch_sampled_mask]

        # used for fg/bg
        rcnn_reg_weights = rcnn_reg_weights[batch_sampled_mask]
        # number of positive regression weights, used as normaliser
        num_reg_coeff = (rcnn_reg_weights > 0).sum(dim=-1)
        assert num_reg_coeff, 'bug happens'
        prediction_dict[
            'rcnn_reg_weights'] = rcnn_reg_weights / num_reg_coeff.float()

        if self.enable_reg:
            prediction_dict['rcnn_reg_targets'] = rcnn_reg_targets[
                batch_sampled_mask]

        # match result kept by the assigner's analyzer; presumably filled by
        # the assign() call above -- TODO confirm
        prediction_dict['fake_match'] = self.target_assigner.analyzer.match[
            batch_sampled_mask]

        # update rois_batch
        prediction_dict['rois_batch'] = rois_batch[batch_sampled_mask].view(
            rois_batch.shape[0], -1, 5)

    def loss(self, prediction_dict, feed_dict):
        """Compute the RPN losses plus the losses of the enabled heads.

        Targets and weights are read from ``prediction_dict``, where
        ``pre_subsample`` stored them; no assignment happens here. When
        ``enable_cls`` is set the analyzer's ``analyze_ap`` is also called
        on the foreground probabilities.

        Args:
            prediction_dict: output of ``forward`` in training mode.
            feed_dict: dict holding ``'gt_labels'``; also passed through to
                ``rpn_model.loss``.

        Returns:
            dict with the RPN loss entries plus ``'rcnn_cls_loss'`` and/or
            ``'rcnn_bbox_loss'`` depending on the enabled heads.
        """
        loss_dict = {}

        # submodule loss
        loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

        # add rcnn_cls_targets to get the statics of rpn
        #  loss_dict['rcnn_cls_targets'] = rcnn_cls_targets

        if self.enable_cls:
            # targets and weights
            rcnn_cls_weights = prediction_dict['rcnn_cls_weights']

            rcnn_cls_targets = prediction_dict['rcnn_cls_targets']
            # classification loss
            rcnn_cls_scores = prediction_dict['rcnn_cls_scores']
            rcnn_cls_loss = self.rcnn_cls_loss(rcnn_cls_scores,
                                               rcnn_cls_targets)
            rcnn_cls_loss *= rcnn_cls_weights
            rcnn_cls_loss = rcnn_cls_loss.sum(dim=-1)

            loss_dict['rcnn_cls_loss'] = rcnn_cls_loss

        if self.enable_reg:
            rcnn_reg_weights = prediction_dict['rcnn_reg_weights']
            rcnn_reg_targets = prediction_dict['rcnn_reg_targets']

            # bounding box regression L1 loss: sum over the last dim of the
            # unreduced loss, weight each RoI, then sum again
            rcnn_bbox_preds = prediction_dict['rcnn_bbox_preds']
            rcnn_bbox_loss = self.rcnn_bbox_loss(rcnn_bbox_preds,
                                                 rcnn_reg_targets).sum(dim=-1)
            rcnn_bbox_loss *= rcnn_reg_weights
            rcnn_bbox_loss = rcnn_bbox_loss.sum(dim=-1)

            # loss weights has no gradients
            loss_dict['rcnn_bbox_loss'] = rcnn_bbox_loss

        # analyse AP; only meaningful when the class head is enabled
        if self.enable_cls:
            rcnn_cls_probs = prediction_dict['rcnn_cls_probs']
            num_gt = feed_dict['gt_labels'].numel()
            fake_match = prediction_dict['fake_match']
            self.target_assigner.analyzer.analyze_ap(fake_match,
                                                     rcnn_cls_probs[:, 1],
                                                     num_gt,
                                                     thresh=0.5)

        return loss_dict
Пример #6
0
class PostCLSFasterRCNN(Model):
    def forward(self, feed_dict):
        """Run the three stages: RPN, box regression, then re-scoring.

        The third stage decodes the regressed boxes, pools features for
        those boxes and classifies them. Statistics are collected in
        ``self.stats`` (first-stage RoIs) and ``self.rcnn_stats`` (decoded
        boxes), both of which are reset at the start of every call.

        Args:
            feed_dict: dict holding at least ``'img'``; ``'gt_boxes'`` and
                ``'gt_labels'`` are read by the target assigner, the
                optional ground-truth append and the AP analysis. The key
                ``'base_feat'`` is added in place.

        Returns:
            dict with the RPN outputs plus ``'rcnn_bbox_preds'``,
            ``'second_rpn_anchors'``, ``'second_rpn_cls_probs'``,
            ``'rcnn_rois_batch'`` and, when ``enable_cls`` is set,
            ``'rcnn_cls_probs'`` and ``'rcnn_cls_scores'``.
        """
        # some pre forward hook
        self.clean_stats()

        prediction_dict = {}

        ################################
        # first stage
        ################################
        # base model
        base_feat = self.feature_extractor.first_stage_feature(
            feed_dict['img'])
        # the RPN reads the shared feature map from feed_dict
        feed_dict.update({'base_feat': base_feat})
        self.add_feat('base_feat', base_feat)

        # rpn model
        prediction_dict.update(self.rpn_model.forward(feed_dict))

        #####################################
        # second stage(bbox regression)
        #####################################
        # subsample the proposals first to reduce memory consumption
        if self.training and self.enable_reg:
            # append gt
            if self.use_gt:
                prediction_dict['rois_batch'] = self.append_gt(
                    prediction_dict['rois_batch'], feed_dict['gt_boxes'])
            # called with the default stage; its return value is merged
            # into the first-stage stats
            stats = self.pre_subsample(prediction_dict, feed_dict)
            # rois stats
            self.stats.update(stats)
        rois_batch = prediction_dict['rois_batch']

        # note here base_feat (N,C,H,W),rois_batch (N,num_proposals,5)
        pooled_feat = self.rcnn_pooling(base_feat, rois_batch.view(-1, 5))

        # The regression branch always runs, regardless of enable_reg.
        # shape(N,C,1,1)
        pooled_feat_reg = self.feature_extractor.second_stage_feature(
            pooled_feat)

        # average over the two spatial dimensions
        pooled_feat_reg = pooled_feat_reg.mean(3).mean(2)
        rcnn_bbox_preds = self.rcnn_bbox_pred(pooled_feat_reg)
        prediction_dict['rcnn_bbox_preds'] = rcnn_bbox_preds

        # used for tracking: pick the anchors and RPN probabilities that
        # correspond to the kept proposals
        proposals_order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = prediction_dict['anchors'][
            proposals_order]
        prediction_dict['second_rpn_cls_probs'] = prediction_dict[
            'rpn_cls_probs'][0][proposals_order]

        ###########################################
        # third stage(predict scores of final bbox)
        ###########################################

        # decode rcnn bbox, generate rcnn rois batch
        # NOTE(review): view(1, -1, 4) presumably assumes one image per
        # batch and class-agnostic regression -- TODO confirm
        pred_boxes = self.bbox_coder.decode_batch(
            rcnn_bbox_preds.view(1, -1, 4), rois_batch[:, :, 1:5])
        # column 0 stays zero; columns 1-4 hold the decoded boxes, detached
        # so the third stage does not back-propagate into the regressor
        rcnn_rois_batch = torch.zeros_like(rois_batch)
        rcnn_rois_batch[:, :, 1:5] = pred_boxes.detach()
        prediction_dict['rcnn_rois_batch'] = rcnn_rois_batch

        if self.training and self.use_gt:
            # append gt
            rcnn_rois_batch = self.append_gt(rcnn_rois_batch,
                                             feed_dict['gt_boxes'])
            prediction_dict['rcnn_rois_batch'] = rcnn_rois_batch

        if self.enable_cls:
            if self.training:
                # second subsampling pass, this time with stage='rcnn';
                # its return value is merged into the final-box stats
                rcnn_stats = self.pre_subsample(prediction_dict,
                                                feed_dict,
                                                stage='rcnn')
                # rcnn stats
                self.rcnn_stats.update(rcnn_stats)

            # rois after subsample
            pred_rois = prediction_dict['rcnn_rois_batch']
            pooled_feat_cls = self.rcnn_pooling(base_feat,
                                                pred_rois.view(-1, 5))
            # detach so the class branch cannot update the shared features
            pooled_feat_cls = self.feature_extractor.third_stage_feature(
                pooled_feat_cls.detach())

            # shape(N,C)
            pooled_feat_cls = pooled_feat_cls.mean(3).mean(2)
            rcnn_cls_scores = self.rcnn_cls_pred(pooled_feat_cls)
            rcnn_cls_probs = F.softmax(rcnn_cls_scores, dim=1)

            prediction_dict['rcnn_cls_probs'] = rcnn_cls_probs
            prediction_dict['rcnn_cls_scores'] = rcnn_cls_scores

        ###################################
        # stats
        ###################################
        # import ipdb
        # ipdb.set_trace()
        if not self.training or (self.enable_track_rois
                                 and not self.enable_reg):
            # skipped in training when enable_reg is set; stats were already
            # merged from pre_subsample in that case
            # the last element returned by assign() is merged as stats
            stats = self.target_assigner.assign(rois_batch[:, :, 1:],
                                                feed_dict['gt_boxes'],
                                                feed_dict['gt_labels'])[-1]
            self.stats.update(stats)

        if not self.training or (self.enable_track_rcnn_rois
                                 and not self.enable_cls):
            # skipped in training when enable_cls is set; stats were already
            # merged from pre_subsample(stage='rcnn') in that case
            stats = self.target_assigner.assign(rcnn_rois_batch[:, :, 1:],
                                                feed_dict['gt_boxes'],
                                                feed_dict['gt_labels'])[-1]
            self.rcnn_stats.update(stats)

        # analyse AP; only meaningful when the class head is enabled
        if self.training and self.enable_cls:
            rcnn_cls_probs = prediction_dict['rcnn_cls_probs']
            num_gt = feed_dict['gt_labels'].numel()
            fake_match = self.rcnn_stats['match']
            stats = self.target_assigner.analyzer.analyze_ap(fake_match,
                                                             rcnn_cls_probs[:,
                                                                            1],
                                                             num_gt,
                                                             thresh=0.5)
            # collect stats
            self.rcnn_stats.update(stats)

        return prediction_dict

    def append_gt(self, rois_batch, gt_boxes):
        """Concatenate the ground-truth boxes to the RoIs along dim 1.

        Every ground-truth box is widened to five columns: column 0 is left
        at zero and columns 1-4 receive the first four values of the box.

        Args:
            rois_batch: tensor with five values per RoI in its last dim.
            gt_boxes: 3-d tensor with at least four values per box in its
                last dim.

        Returns:
            Tensor holding ``rois_batch`` followed by the widened
            ground-truth boxes, joined along dim 1.
        """
        num_images, num_gt = gt_boxes.shape[0], gt_boxes.shape[1]
        # NOTE(review): column 0 stays 0 for every image; if it is a batch
        # index this is only right for single-image batches -- confirm.
        padded_gt = torch.zeros(num_images, num_gt, 5).type_as(gt_boxes)
        padded_gt[:, :, 1:5] = gt_boxes[:, :, :4]
        return torch.cat((rois_batch, padded_gt), dim=1)

    def init_weights(self):
        """Initialise the submodules and both linear heads.

        The feature extractor and the RPN initialise themselves. The two
        heads are then filled via ``Filler.normal_init`` with mean 0 and a
        head-specific standard deviation.
        """
        # submodules first
        self.feature_extractor.init_weights()
        self.rpn_model.init_weights()

        # (attribute name, std) per head; classifier before regressor
        head_stds = (('rcnn_cls_pred', 0.01), ('rcnn_bbox_pred', 0.001))
        for head_name, std in head_stds:
            Filler.normal_init(getattr(self, head_name), 0, std,
                               self.truncated)

    def init_modules(self):
        """Create the backbone, RPN, RoI pooling op, prediction heads and losses.

        Everything is driven by attributes already stored on ``self`` by
        ``init_param``. The created objects are stored as attributes.

        Raises:
            NotImplementedError: if ``pooling_mode`` is ``'psalign'`` or
                ``'deformable_psalign'``.
        """
        self.feature_extractor = ResNetFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)

        # RoI pooling operator. A mode that matches none of the names below
        # leaves ``rcnn_pooling`` unset.
        mode = self.pooling_mode
        if mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(
                self.pooling_size, self.pooling_size, 1.0 / 16.0)
        elif mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)
        elif mode in ('psalign', 'deformable_psalign'):
            raise NotImplementedError('have not implemented yet!')

        # Prediction heads (classifier first, then box regressor).
        self.rcnn_cls_pred = nn.Linear(2048, self.n_classes)
        in_channels = 2048 if self.reduce else 2048 * 4 * 4
        num_box_outputs = 4 if self.class_agnostic else 4 * self.n_classes
        self.rcnn_bbox_pred = nn.Linear(in_channels, num_box_outputs)

        # Loss modules.
        if self.use_focal_loss:
            cls_loss = FocalLoss(2, alpha=0.25, gamma=2)
        else:
            cls_loss = functools.partial(F.cross_entropy, reduce=False)
        self.rcnn_cls_loss = cls_loss
        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

    def init_param(self, model_config):
        """Copy settings from ``model_config`` and set the feature switches.

        Args:
            model_config: mapping that must contain the keys ``'classes'``,
                ``'class_agnostic'``, ``'pooling_size'``, ``'pooling_mode'``,
                ``'crop_resize_with_max_pool'``, ``'truncated'``,
                ``'use_focal_loss'``, ``'subsample_twice'``,
                ``'rcnn_batch_size'``, ``'feature_extractor_config'``,
                ``'rpn_config'``, ``'target_assigner_config'`` and
                ``'sampler_config'``.
        """
        class_names = model_config['classes']
        self.classes = class_names
        self.n_classes = len(class_names)

        # Values stored under the same name as their config key, looked up
        # in this order.
        copied_keys = (
            'class_agnostic',
            'pooling_size',
            'pooling_mode',
            'crop_resize_with_max_pool',
            'truncated',
            'use_focal_loss',
            'subsample_twice',
            'rcnn_batch_size',
            # submodule configs
            'feature_extractor_config',
            'rpn_config',
        )
        for key in copied_keys:
            setattr(self, key, model_config[key])

        # The assigner also provides the box coder and similarity calculator.
        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])
        self.bbox_coder = self.target_assigner.bbox_coder
        self.similarity_calc = self.target_assigner.similarity_calc

        self.sampler = BalancedSampler(model_config['sampler_config'])

        # Fixed switches; none of them is read from the config.
        fixed_switches = (
            ('reduce', True),
            ('enable_cls', True),  # optimize cls
            ('enable_reg', False),  # optimize reg
            ('enable_iou', False),  # cal iou
            ('enable_track_rois', True),  # track good rois
            ('enable_track_rcnn_rois', True),
            ('enable_eval_final_bbox', True),  # eval the final bbox
            ('use_gt', False),  # use gt
            ('subsample', False),
        )
        for name, value in fixed_switches:
            setattr(self, name, value)

    def clean_stats(self):
        """Reset ``self.stats`` and ``self.rcnn_stats`` to fresh dicts."""

        def fresh_stats():
            # a new dict per call so the two attributes never share state
            return {
                'num_det': 1,
                'num_tp': 0,
                'matched_thresh': 0,
                'recall_thresh': 0,
                'match': None
            }

        # rois bbox
        self.stats = fresh_stats()

        # rcnn bbox (final bbox)
        self.rcnn_stats = fresh_stats()

    def pre_subsample(self, prediction_dict, feed_dict, stage='rpn'):
        """Assign targets to the RoIs of one stage and optionally subsample.

        Args:
            prediction_dict: read for the RoIs (key ``'rois_batch'`` when
                ``stage == 'rpn'``, otherwise ``'rcnn_rois_batch'``) and
                updated in place with weights, targets and the kept RoIs.
            feed_dict: provides ``'gt_boxes'`` and ``'gt_labels'``.
            stage: selects which RoI entry is processed.

        Returns:
            The stats dict returned by the target assigner, with its
            ``'match'`` entry restricted to the kept RoIs.

        Raises:
            ValueError: subsampling is on but neither ``enable_reg`` nor
                ``enable_cls`` is set.
        """
        rois_name = 'rois_batch' if stage == 'rpn' else 'rcnn_rois_batch'
        rois_batch = prediction_dict[rois_name]
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        # assigner: the first RoI column is dropped before matching
        assigned = self.target_assigner.assign(rois_batch[:, :, 1:],
                                               gt_boxes, gt_labels)
        cls_targets, reg_targets, cls_weights, reg_weights, stats = assigned

        # subsampler
        if not self.subsample:
            # keep every RoI
            keep = torch.ones_like(cls_weights > 0)
        else:
            if self.enable_reg:
                # used for reg training
                positives, candidates = reg_weights > 0, None
            elif self.enable_cls:
                # used for cls training
                positives, candidates = cls_targets > 0, cls_weights > 0
            else:
                raise ValueError(
                    "please check enable reg and enable cls again")

            keep = self.sampler.subsample_batch(
                self.rcnn_batch_size,
                positives,
                indicator=candidates,
                criterion=None)

        if self.enable_cls:
            cls_weights = cls_weights[keep]
            num_cls = (cls_weights > 0).sum(dim=-1)
            assert num_cls, 'bug happens'
            prediction_dict['rcnn_cls_weights'] = (
                cls_weights / num_cls.float())

        # used for retrieving statistics
        prediction_dict['rcnn_cls_targets'] = cls_targets[keep]

        # used for fg/bg; the count is clamped to at least one instead of
        # asserted, so a selection without positives does not divide by zero
        reg_weights = reg_weights[keep]
        num_reg = (reg_weights > 0).sum(dim=-1)
        num_reg = torch.max(num_reg, torch.ones_like(num_reg))
        prediction_dict['rcnn_reg_weights'] = reg_weights / num_reg.float()

        if self.enable_reg:
            prediction_dict['rcnn_reg_targets'] = reg_targets[keep]

        stats['match'] = stats['match'][keep]

        # write the kept RoIs back under the same key
        kept_rois = rois_batch[keep]
        prediction_dict[rois_name] = kept_rois.view(rois_batch.shape[0], -1, 5)
        return stats

    def loss(self, prediction_dict, feed_dict):
        """Compute the losses of the heads that are currently enabled.

        Targets and weights are read from ``prediction_dict``; nothing is
        assigned or sampled here.

        Returns:
            dict that holds ``'rcnn_cls_loss'`` when ``enable_cls`` is set,
            and the RPN losses plus ``'rcnn_bbox_loss'`` when ``enable_reg``
            is set.
        """
        loss_dict = {}

        if self.enable_cls:
            cls_weights = prediction_dict['rcnn_cls_weights']
            cls_targets = prediction_dict['rcnn_cls_targets']
            cls_scores = prediction_dict['rcnn_cls_scores']

            # classification loss, weighted and then summed
            cls_loss = self.rcnn_cls_loss(cls_scores, cls_targets)
            cls_loss *= cls_weights
            loss_dict['rcnn_cls_loss'] = cls_loss.sum(dim=-1)

        if self.enable_reg:
            # the RPN loss is only included together with the box loss
            loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

            reg_weights = prediction_dict['rcnn_reg_weights']
            reg_targets = prediction_dict['rcnn_reg_targets']
            bbox_preds = prediction_dict['rcnn_bbox_preds']

            # bounding box regression loss: sum over the last dim first,
            # then weight and sum again
            bbox_loss = self.rcnn_bbox_loss(bbox_preds, reg_targets)
            bbox_loss = bbox_loss.sum(dim=-1)
            bbox_loss *= reg_weights
            loss_dict['rcnn_bbox_loss'] = bbox_loss.sum(dim=-1)

        return loss_dict
Пример #7
0
class DetachFasterRCNN(Model):
    def collect_intermedia_layers(self, img):
        """Run the first-stage backbone and keep its last two activations.

        Returns:
            ``(feat3, end_points)``: ``feat3`` is the output of the last
            first-stage layer, and ``end_points`` maps ``'feat2'`` (the
            input of that last layer) and ``'feat3'`` to their tensors.
        """
        first_stage = self.feature_extractor.first_stage_feature
        feat2 = first_stage[:-1](img)
        feat3 = first_stage[-1](feat2)
        return feat3, {'feat2': feat2, 'feat3': feat3}

    def caroi_pooling(self, all_feats, rois_batch, out_channels):
        """Pool the ``'feat2'`` feature map over the RoIs.

        Only the ``'feat2'`` entry of ``all_feats`` is used; every other
        entry is ignored. The pooled result goes through
        ``self.reduce_pooling`` when its channel count (dim 1) differs
        from ``out_channels``.
        """
        pooled = [
            self.rcnn_pooling2(all_feats[name], rois_batch)
            for name in all_feats if name == 'feat2'
        ]
        merged = torch.cat(pooled, dim=1)
        if merged.shape[1] == out_channels:
            return merged
        # channel mismatch: add 1x1 conv
        return self.reduce_pooling(merged)

    def forward(self, feed_dict):
        """Run backbone, RPN and the two detached second-stage branches.

        ``feed_dict`` gains a ``'base_feat'`` entry. The classification
        branch receives a detached copy of the pooled features, while the
        regression branch receives the pooled features directly.

        Returns:
            The prediction dict: the RPN outputs plus ``'rcnn_cls_probs'``,
            ``'rcnn_bbox_preds'``, ``'rcnn_cls_scores'`` and
            ``'second_rpn_anchors'``.
        """
        prediction_dict = {}

        # base model
        base_feat, all_feats = self.collect_intermedia_layers(feed_dict['img'])
        feed_dict['base_feat'] = base_feat
        self.add_feat('feat2', all_feats['feat2'])
        self.add_feat('base_feat', base_feat)

        # rpn model
        prediction_dict.update(self.rpn_model.forward(feed_dict))

        # pre subsample to reduce memory use (training only)
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict)

        # rois are flattened to rows of 5 values before pooling
        flat_rois = prediction_dict['rois_batch'].view(-1, 5)
        pooled_feat = self.rcnn_pooling(base_feat, flat_rois)

        extractor = self.feature_extractor
        reg_feat = extractor.second_stage_feature(pooled_feat)
        cls_feat = extractor.third_stage_feature(pooled_feat.detach())
        self.add_feat('base_feat', base_feat)

        # average over the two trailing dims
        cls_feat = cls_feat.mean(3).mean(2)
        reg_feat = reg_feat.mean(3).mean(2)
        cls_scores = self.rcnn_cls_pred(cls_feat)
        bbox_preds = self.rcnn_bbox_pred(reg_feat)

        prediction_dict['rcnn_cls_probs'] = F.softmax(cls_scores, dim=1)
        prediction_dict['rcnn_bbox_preds'] = bbox_preds
        prediction_dict['rcnn_cls_scores'] = cls_scores

        # used for track
        order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = prediction_dict['anchors'][
            order]

        return prediction_dict

    def generate_channel_attention(self, feat):
        """Average ``feat`` over dims 3 and 2, keeping both as size 1."""
        averaged_last = feat.mean(3, keepdim=True)
        return averaged_last.mean(2, keepdim=True)

    def generate_spatial_attention(self, feat):
        """Apply the ``spatial_attention`` layer to ``feat``.

        The layer is only created in ``init_modules`` when
        ``use_self_attention`` is truthy.
        """
        attention = self.spatial_attention(feat)
        return attention

    def init_weights(self):
        """Initialise the sub-modules, then the two RCNN prediction layers.

        The feature extractor and the RPN run their own ``init_weights``.
        The classification and box-regression layers go through
        ``Filler.normal_init`` with the constants ``(0, 0.01)`` and
        ``(0, 0.001)`` (presumably mean and std -- confirm against
        ``Filler``) plus the configured ``truncated`` flag.
        """
        for submodule in (self.feature_extractor, self.rpn_model):
            submodule.init_weights()

        head_scales = ((self.rcnn_cls_pred, 0.01),
                       (self.rcnn_bbox_pred, 0.001))
        for head, scale in head_scales:
            Filler.normal_init(head, 0, scale, self.truncated)

    def init_modules(self):
        """Build the backbone, RPN, pooling layers, heads and losses.

        Raises:
            NotImplementedError: if ``pooling_mode`` is ``'psalign'`` or
                ``'deformable_psalign'``.
        """
        self.feature_extractor = feature_extractors_builder.build(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)

        # roi pooling; any mode not listed here leaves rcnn_pooling unset
        mode = self.pooling_mode
        pool_size = self.pooling_size
        if mode in ('psalign', 'deformable_psalign'):
            raise NotImplementedError('have not implemented yet!')
        if mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(pool_size, pool_size, 1.0 / 16.0)
        elif mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)

        # prediction heads, both fed with ``ndin`` features
        width = self.ndin
        self.rcnn_cls_pred = nn.Linear(width, self.n_classes)
        num_bbox_outputs = 4 if self.class_agnostic else 4 * self.n_classes
        self.rcnn_bbox_pred = nn.Linear(width, num_bbox_outputs)

        # loss modules; cross entropy and smooth L1 are created with
        # reduce=False so the caller can weight the losses itself
        if not self.use_focal_loss:
            self.rcnn_cls_loss = functools.partial(F.cross_entropy,
                                                   reduce=False)
        else:
            self.rcnn_cls_loss = FocalLoss(2,
                                           gamma=2,
                                           alpha=0.2,
                                           auto_alpha=False)

        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

        # attention
        if self.use_self_attention:
            self.spatial_attention = nn.Conv2d(width, 1, 3, 1, 1)

        # extra pooling at scale 1/8 and the 1x1 conv that caroi_pooling
        # applies on a channel mismatch
        self.rcnn_pooling2 = RoIAlignAvg(pool_size, pool_size, 1.0 / 8.0)
        self.reduce_pooling = nn.Sequential(nn.Conv2d(512, 1024, 1, 1, 0),
                                            nn.ReLU())

    def init_param(self, model_config):
        """Copy settings from ``model_config`` and build assigner/sampler.

        ``'din'`` and ``'use_self_attention'`` are optional; a missing or
        falsy ``'din'`` falls back to 512. Every other key read below must
        be present.
        """
        self.ndin = model_config.get('din') or 512

        classes = model_config['classes']
        self.classes = classes
        self.n_classes = len(classes)

        # values copied verbatim onto attributes of the same name
        for key in ('class_agnostic', 'pooling_size', 'pooling_mode',
                    'crop_resize_with_max_pool', 'truncated',
                    'use_focal_loss', 'subsample_twice', 'rcnn_batch_size'):
            setattr(self, key, model_config[key])
        self.use_self_attention = model_config.get('use_self_attention')

        # some submodule config
        for key in ('feature_extractor_config', 'rpn_config'):
            setattr(self, key, model_config[key])

        # assigner
        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])

        # sampler
        self.sampler = BalancedSampler(model_config['sampler_config'])

    def pre_subsample(self, prediction_dict, feed_dict):
        """Assign targets to the RoIs and keep a sampled subset of them.

        ``prediction_dict`` is updated in place with the normalised
        ``'rcnn_cls_weights'`` / ``'rcnn_reg_weights'``, the selected
        ``'rcnn_cls_targets'`` / ``'rcnn_reg_targets'``, ``'fake_match'``
        and the kept ``'rois_batch'``. Outside training mode
        ``'proposals_order'`` is filtered with the same mask.

        Raises:
            AssertionError: no positive classification or regression
                weight survives the sampling.
        """
        rois_batch = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        # assigner: the first RoI column is dropped before matching
        assigned = self.target_assigner.assign(rois_batch[:, :, 1:],
                                               gt_boxes, gt_labels)
        cls_targets, reg_targets, cls_weights, reg_weights = assigned

        # subsampler: positives come from the reg weights, candidates from
        # the cls weights; result has shape (N,M)
        keep = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            reg_weights > 0,
            indicator=cls_weights > 0,
            criterion=None)

        def normalised(weights):
            # select the kept entries and divide by the number of positives
            kept = weights[keep]
            num_positive = (kept > 0).sum(dim=-1)
            assert num_positive, 'bug happens'
            return kept / num_positive.float()

        prediction_dict['rcnn_cls_weights'] = normalised(cls_weights)
        prediction_dict['rcnn_reg_weights'] = normalised(reg_weights)
        prediction_dict['rcnn_cls_targets'] = cls_targets[keep]
        prediction_dict['rcnn_reg_targets'] = reg_targets[keep]
        prediction_dict['fake_match'] = self.target_assigner.analyzer.match[
            keep]

        # update rois_batch
        kept_rois = rois_batch[keep]
        prediction_dict['rois_batch'] = kept_rois.view(
            rois_batch.shape[0], -1, 5)

        if not self.training:
            # used for track
            order = prediction_dict['proposals_order']
            prediction_dict['proposals_order'] = order[keep]

    def loss(self, prediction_dict, feed_dict):
        """Compute RPN, classification and box-regression losses.

        Targets and weights are the ones stored in ``prediction_dict`` by
        ``pre_subsample``. As a side effect the assigner's analyzer is fed
        the class-1 probabilities via ``analyze_ap``.

        Returns:
            dict with the RPN losses plus ``'rcnn_cls_loss'`` and
            ``'rcnn_bbox_loss'``.
        """
        # submodule loss
        loss_dict = {}
        loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

        # targets and weights
        cls_weights = prediction_dict['rcnn_cls_weights']
        reg_weights = prediction_dict['rcnn_reg_weights']
        cls_targets = prediction_dict['rcnn_cls_targets']
        reg_targets = prediction_dict['rcnn_reg_targets']

        # classification loss, weighted and then summed
        cls_scores = prediction_dict['rcnn_cls_scores']
        cls_loss = self.rcnn_cls_loss(cls_scores, cls_targets)
        cls_loss *= cls_weights
        cls_loss = cls_loss.sum(dim=-1)

        # bounding box regression loss: sum over the last dim first, then
        # weight once and sum again
        bbox_preds = prediction_dict['rcnn_bbox_preds']
        bbox_loss = self.rcnn_bbox_loss(bbox_preds, reg_targets)
        bbox_loss = bbox_loss.sum(dim=-1)
        bbox_loss *= reg_weights
        bbox_loss = bbox_loss.sum(dim=-1)

        loss_dict['rcnn_cls_loss'] = cls_loss
        loss_dict['rcnn_bbox_loss'] = bbox_loss

        # analysis ap
        analyzer = self.target_assigner.analyzer
        analyzer.analyze_ap(prediction_dict['fake_match'],
                            prediction_dict['rcnn_cls_probs'][:, 1],
                            feed_dict['gt_labels'].numel(),
                            thresh=0.5)

        return loss_dict
class DoubleIoUSecondStageFasterRCNN(Model):
    def forward(self, feed_dict):
        """Run backbone, RPN and the decoupled cls / bbox branches.

        ``feed_dict`` gains a ``'base_feat'`` entry. After RoI pooling a
        shared conv + ReLU is applied; the classification branch then
        works on a detached copy of that feature, while the regression
        branch works on the feature itself.

        Returns:
            The prediction dict: the RPN outputs plus ``'rcnn_cls_probs'``,
            ``'rcnn_bbox_preds'``, ``'rcnn_cls_scores'`` and
            ``'second_rpn_anchors'``.
        """
        prediction_dict = {}
        extractor = self.feature_extractor

        # base model
        base_feat = extractor.first_stage_feature(feed_dict['img'])
        feed_dict['base_feat'] = base_feat
        self.add_feat('base_feat', base_feat)

        # rpn model
        prediction_dict.update(self.rpn_model.forward(feed_dict))

        # pre subsample to reduce memory use (training only)
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict)

        # rois are flattened to rows of 5 values before pooling
        flat_rois = prediction_dict['rois_batch'].view(-1, 5)
        shared = self.rcnn_pooling(base_feat, flat_rois)
        shared = F.relu(self.rcnn_conv(shared), inplace=True)

        cls_feat = self.rcnn_pooled_feat_cls(shared.detach())
        bbox_feat = self.rcnn_pooled_feat_bbox(shared)

        # classification
        cls_feat = extractor.third_stage_feature(cls_feat)
        cls_scores = self.rcnn_cls_pred(cls_feat.mean(3).mean(2))
        cls_probs = F.softmax(cls_scores, dim=1)

        # regression
        reg_feat = extractor.second_stage_feature(bbox_feat)
        bbox_preds = self.rcnn_bbox_pred(reg_feat.mean(3).mean(2))

        prediction_dict['rcnn_cls_probs'] = cls_probs
        prediction_dict['rcnn_bbox_preds'] = bbox_preds
        prediction_dict['rcnn_cls_scores'] = cls_scores

        # used for track
        order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = prediction_dict['anchors'][
            order]

        return prediction_dict

    def unfreeze_part_modules(self, model):
        """Set ``requires_grad = True`` on every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad = True

        #  model = self.feature_extractor.first_stage_feature

        # def freeze_part_modules(self):
        # pass

        # def rcnn_cls_pred(pooled_feat)

    def init_weights(self):
        """Initialise the sub-modules, then the two RCNN prediction layers.

        The feature extractor and the RPN run their own ``init_weights``.
        The classification and box-regression layers go through
        ``Filler.normal_init`` with the constants ``(0, 0.01)`` and
        ``(0, 0.001)`` (presumably mean and std -- confirm against
        ``Filler``) plus the configured ``truncated`` flag.
        """
        for submodule in (self.feature_extractor, self.rpn_model):
            submodule.init_weights()

        head_scales = ((self.rcnn_cls_pred, 0.01),
                       (self.rcnn_bbox_pred, 0.001))
        for head, scale in head_scales:
            Filler.normal_init(head, 0, scale, self.truncated)
        # if self.training_stage == 'cls':
        # self.freeze_modules()
        # unfreeze part
        # models = [

    # #  self.feature_extractor.first_stage_feature,
    # #  self.feature_extractor.second_stage_feature,
    # self.feature_extractor.third_stage_feature
    # ]
    # for model in models:
    # self.unfreeze_part_modules(model)

    def init_modules(self):
        """Build the backbone, RPN, RoI pooling, heads, losses and convs.

        Raises:
            NotImplementedError: if ``pooling_mode`` is ``'psalign'`` or
                ``'deformable_psalign'``.
        """
        self.feature_extractor = ResNetFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)

        # roi pooling; any mode not listed here leaves rcnn_pooling unset
        mode = self.pooling_mode
        if mode in ('psalign', 'deformable_psalign'):
            raise NotImplementedError('have not implemented yet!')
        if mode == 'align':
            pool_size = self.pooling_size
            self.rcnn_pooling = RoIAlignAvg(pool_size, pool_size, 1.0 / 16.0)
        elif mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)

        # prediction heads
        self.rcnn_cls_pred = nn.Linear(2048, self.n_classes)
        in_channels = 2048 if self.reduce else 2048 * 4 * 4
        num_bbox_outputs = 4 if self.class_agnostic else 4 * self.n_classes
        self.rcnn_bbox_pred = nn.Linear(in_channels, num_bbox_outputs)

        # loss modules; cross entropy and smooth L1 are created with
        # reduce=False so the caller can weight the losses itself
        if not self.use_focal_loss:
            self.rcnn_cls_loss = functools.partial(F.cross_entropy,
                                                   reduce=False)
        else:
            self.rcnn_cls_loss = FocalLoss(2)

        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

        # decouple cls and bbox: one shared 3x3 conv (1024 -> 512), then a
        # separate 1x1 conv (512 -> 1024) per branch
        self.rcnn_conv = nn.Conv2d(1024, 512, 3, 1, 1, bias=True)
        self.rcnn_pooled_feat_cls = nn.Conv2d(512, 1024, 1, 1, 0)
        self.rcnn_pooled_feat_bbox = nn.Conv2d(512, 1024, 1, 1, 0)

    def init_param(self, model_config):
        """Copy settings from ``model_config`` and build the helper objects.

        Every key read below must be present in ``model_config``.
        ``reduce`` is hard-coded to ``True`` rather than read from it.
        """
        classes = model_config['classes']
        self.classes = classes
        self.n_classes = len(classes)

        # values copied verbatim onto attributes of the same name
        copied_keys = (
            'class_agnostic',
            'pooling_size',
            'pooling_mode',
            'crop_resize_with_max_pool',
            'truncated',
            'use_focal_loss',
            'subsample_twice',
            'rcnn_batch_size',
            # some submodule config
            'feature_extractor_config',
            'rpn_config',
        )
        for key in copied_keys:
            setattr(self, key, model_config[key])

        # assigner
        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])

        # sampler
        self.sampler = BalancedSampler(model_config['sampler_config'])

        self.reduce = True
        self.visualizer = FeatVisualizer()

    def pre_subsample(self, prediction_dict, feed_dict):
        """Assign targets to the RoIs and keep a sampled subset of them.

        ``prediction_dict`` is updated in place with the normalised
        ``'rcnn_cls_weights'`` / ``'rcnn_reg_weights'``, the selected
        ``'rcnn_cls_targets'`` / ``'rcnn_reg_targets'``, ``'fake_match'``
        and the kept ``'rois_batch'``. Outside training mode
        ``'proposals_order'`` is filtered with the same mask.

        Raises:
            AssertionError: no positive classification or regression
                weight survives the sampling.
        """
        rois_batch = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        # assigner: the first RoI column is dropped before matching
        assigned = self.target_assigner.assign(rois_batch[:, :, 1:],
                                               gt_boxes, gt_labels)
        cls_targets, reg_targets, cls_weights, reg_weights = assigned

        # subsampler: positives come from the reg weights, candidates from
        # the cls weights; result has shape (N,M)
        keep = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            reg_weights > 0,
            indicator=cls_weights > 0,
            criterion=None)

        def normalised(weights):
            # select the kept entries and divide by the number of positives
            kept = weights[keep]
            num_positive = (kept > 0).sum(dim=-1)
            assert num_positive, 'bug happens'
            return kept / num_positive.float()

        prediction_dict['rcnn_cls_weights'] = normalised(cls_weights)
        prediction_dict['rcnn_reg_weights'] = normalised(reg_weights)
        prediction_dict['rcnn_cls_targets'] = cls_targets[keep]
        prediction_dict['fake_match'] = self.target_assigner.analyzer.match[
            keep]
        prediction_dict['rcnn_reg_targets'] = reg_targets[keep]

        # update rois_batch
        kept_rois = rois_batch[keep]
        prediction_dict['rois_batch'] = kept_rois.view(
            rois_batch.shape[0], -1, 5)

        if not self.training:
            # used for track
            order = prediction_dict['proposals_order']
            prediction_dict['proposals_order'] = order[keep]

    def loss(self, prediction_dict, feed_dict):
        """Compute RPN, classification and box-regression losses.

        Targets and weights are the ones stored in ``prediction_dict`` by
        ``pre_subsample``. As a side effect the assigner's analyzer is fed
        the class-1 probabilities via ``analyze_ap``.

        Args:
            prediction_dict: must hold ``'rcnn_cls_weights'``,
                ``'rcnn_reg_weights'``, ``'rcnn_cls_targets'``,
                ``'rcnn_reg_targets'``, ``'rcnn_cls_scores'``,
                ``'rcnn_bbox_preds'``, ``'rcnn_cls_probs'`` and
                ``'fake_match'``.
            feed_dict: must hold ``'gt_labels'``.

        Returns:
            dict with the RPN losses plus ``'rcnn_cls_loss'`` and
            ``'rcnn_bbox_loss'``.
        """
        loss_dict = {}

        # submodule loss
        loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

        # targets and weights
        rcnn_cls_weights = prediction_dict['rcnn_cls_weights']
        rcnn_reg_weights = prediction_dict['rcnn_reg_weights']

        rcnn_cls_targets = prediction_dict['rcnn_cls_targets']
        rcnn_reg_targets = prediction_dict['rcnn_reg_targets']

        # classification loss
        rcnn_cls_scores = prediction_dict['rcnn_cls_scores']
        rcnn_cls_loss = self.rcnn_cls_loss(rcnn_cls_scores, rcnn_cls_targets)
        rcnn_cls_loss *= rcnn_cls_weights
        rcnn_cls_loss = rcnn_cls_loss.sum(dim=-1)

        # bounding box regression L1 loss
        rcnn_bbox_preds = prediction_dict['rcnn_bbox_preds']
        rcnn_bbox_loss = self.rcnn_bbox_loss(rcnn_bbox_preds,
                                             rcnn_reg_targets).sum(dim=-1)
        # The weights are already divided by the number of positives in
        # pre_subsample, so they must be applied exactly once; applying
        # them twice squares that normalisation and shrinks the box loss.
        rcnn_bbox_loss *= rcnn_reg_weights
        rcnn_bbox_loss = rcnn_bbox_loss.sum(dim=-1)

        # loss weights has no gradients
        loss_dict['rcnn_cls_loss'] = rcnn_cls_loss
        loss_dict['rcnn_bbox_loss'] = rcnn_bbox_loss

        # analysis ap
        rcnn_cls_probs = prediction_dict['rcnn_cls_probs']
        fake_match = prediction_dict['fake_match']
        num_gt = feed_dict['gt_labels'].numel()
        self.target_assigner.analyzer.analyze_ap(fake_match,
                                                 rcnn_cls_probs[:, 1],
                                                 num_gt,
                                                 thresh=0.5)

        return loss_dict

    def loss_new(self, prediction_dict, feed_dict):
        """Assign targets, sample RoIs by their loss, return RPN + cls loss.

        Unlike ``loss`` this variant runs the target assigner itself and
        passes the weighted per-RoI loss to the sampler as ``criterion``.
        Only entry ``[0]`` of the assigned targets and weights is used for
        the losses. The box loss is computed but not added to the
        returned dict.

        Side effects: the regression weights returned by the assigner are
        masked in place, ``prediction_dict['rcnn_reg_weights']`` is set to
        the normalised regression weights, and the assigner's analyzer is
        fed the class-1 probabilities via ``analyze_ap``.

        Raises:
            AssertionError: no positive classification or regression
                weight survives the sampling.
        """
        # submodule loss
        loss_dict = {}
        loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

        rois_batch = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        # assigner: the first RoI column is dropped before matching
        assigned = self.target_assigner.assign(rois_batch[:, :, 1:],
                                               gt_boxes, gt_labels)
        cls_targets, reg_targets, cls_weights, reg_weights = assigned

        # here the positives are the same set as the candidates
        indicator = cls_weights > 0
        pos_indicator = indicator

        # unweighted per-RoI losses
        cls_scores = prediction_dict['rcnn_cls_scores']
        cls_loss = self.rcnn_cls_loss(cls_scores, cls_targets[0])
        bbox_preds = prediction_dict['rcnn_bbox_preds']
        bbox_loss = self.rcnn_bbox_loss(bbox_preds, reg_targets[0])
        bbox_loss = bbox_loss.sum(dim=-1)

        # subsampler: the weighted loss is the sampling criterion
        criterion = cls_loss * cls_weights + bbox_loss * reg_weights
        keep = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            pos_indicator,
            indicator=indicator,
            criterion=criterion)

        # classification: zero the dropped RoIs, normalise, weight, sum
        cls_weights = cls_weights * keep.type_as(cls_weights)
        num_cls = (cls_weights > 0).sum(dim=-1)
        assert num_cls, 'bug happens'
        cls_weights = cls_weights / num_cls.float()

        cls_loss *= cls_weights[0]
        cls_loss = cls_loss.sum(dim=-1)

        # bbox reg: same steps, but the masking is done in place
        reg_weights *= keep.type_as(reg_weights)
        num_reg = (reg_weights > 0).sum(dim=-1)
        assert num_reg, 'bug happens'
        reg_weights = reg_weights / num_reg.float()

        bbox_loss *= reg_weights[0]
        # computed but deliberately not stored in loss_dict
        bbox_loss = bbox_loss.sum(dim=-1)

        loss_dict['rcnn_cls_loss'] = cls_loss

        # analysis precision; uses the analyzer's full (unsampled) match
        analyzer = self.target_assigner.analyzer
        cls_probs = prediction_dict['rcnn_cls_probs']
        analyzer.analyze_ap(analyzer.match,
                            cls_probs[:, 1],
                            feed_dict['gt_labels'].numel(),
                            thresh=0.5)
        prediction_dict['rcnn_reg_weights'] = reg_weights
        return loss_dict