Example #1
    def init_param(self, model_config):
        self.in_channels = model_config['din']
        self.post_nms_topN = model_config['post_nms_topN']
        self.pre_nms_topN = model_config['pre_nms_topN']
        self.nms_thresh = model_config['nms_thresh']
        self.use_focal_loss = model_config['use_focal_loss']

        # anchor generator
        self.anchor_generator = anchor_generators.build(
            model_config['anchor_generator_config'])
        self.num_anchors = self.anchor_generator.num_anchors
        self.nc_bbox_out = 4 * self.num_anchors
        self.nc_score_out = self.num_anchors * 2

        self.target_generators = TargetGenerator(
            model_config['target_generator_config'])
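
A minimal sketch of the config dict this init_param reads, inferred from the keys it looks up; every value and nested config below is a placeholder, not a real default from the project:

# Hypothetical config sketch (illustrative values only).
model_config = {
    'din': 512,                        # becomes self.in_channels
    'pre_nms_topN': 12000,             # proposals kept before NMS
    'post_nms_topN': 2000,             # proposals kept after NMS
    'nms_thresh': 0.7,
    'use_focal_loss': False,
    'anchor_generator_config': {},     # consumed by anchor_generators.build
    'target_generator_config': {},     # consumed by TargetGenerator
}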
Example #2
    def init_param(self, model_config):
        # including bg
        self.num_classes = len(model_config['classes']) + 1
        self.in_channels = model_config.get('in_channels', 128)
        self.num_regress = model_config.get('num_regress', 4)
        self.feature_extractor_config = model_config[
            'feature_extractor_config']

        self.target_generators = TargetGenerator(
            model_config['target_generator_config'])
        self.anchor_generator = anchor_generators.build(
            model_config['anchor_generator_config'])

        self.num_anchors = self.anchor_generator.num_anchors
        input_size = torch.tensor(model_config['input_size']).float()
        self.anchors = self.anchor_generator.generate(input_size)

        self.use_focal_loss = model_config['use_focal_loss']
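
This variant also pre-generates anchors from a fixed input size. A sketch of a matching config, with the same caveat that all values are placeholders; note that num_classes counts an extra background class:

# Hypothetical config sketch (illustrative values only).
model_config = {
    'classes': ['Car', 'Pedestrian'],  # num_classes = len(classes) + 1 (bg)
    'in_channels': 128,                # optional, default 128
    'num_regress': 4,                  # optional, default 4
    'input_size': [384, 1280],         # (H, W) used to generate anchors once
    'use_focal_loss': True,
    'feature_extractor_config': {},
    'anchor_generator_config': {},
    'target_generator_config': {},
}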
Example #3
    def init_param(self, model_config):
        classes = model_config['classes']
        self.classes = classes
        self.n_classes = len(classes) + 1
        self.class_agnostic = model_config['class_agnostic']
        self.pooling_size = model_config['pooling_size']
        self.pooling_mode = model_config['pooling_mode']
        self.truncated = model_config['truncated']
        self.use_focal_loss = model_config['use_focal_loss']

        # some submodule config
        self.feature_extractor_config = model_config[
            'feature_extractor_config']
        self.rpn_config = model_config['rpn_config']

        self.num_stages = len(model_config['target_generator_config'])
        self.target_generators = [
            TargetGenerator(model_config['target_generator_config'][i])
            for i in range(self.num_stages)
        ]
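
Here target_generator_config is a list with one entry per cascade stage, so the model builds one TargetGenerator per stage. A placeholder sketch:

# Hypothetical config sketch: three cascade stages (illustrative only).
model_config = {
    'classes': ['Car'],
    'class_agnostic': True,
    'pooling_size': 7,
    'pooling_mode': 'align',                  # placeholder mode name
    'truncated': False,
    'use_focal_loss': False,
    'feature_extractor_config': {},
    'rpn_config': {},
    'target_generator_config': [{}, {}, {}],  # num_stages = 3
}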
Example #4
class TwoStageRetinaLayer(Model):
    def init_param(self, model_config):
        # including bg
        self.num_classes = len(model_config['classes']) + 1
        self.in_channels = model_config.get('in_channels', 128)
        self.num_regress = model_config.get('num_regress', 4)
        self.feature_extractor_config = model_config[
            'feature_extractor_config']

        self.target_generators = TargetGenerator(
            model_config['target_generator_config'])
        self.anchor_generator = anchor_generators.build(
            model_config['anchor_generator_config'])

        self.num_anchors = self.anchor_generator.num_anchors
        input_size = torch.tensor(model_config['input_size']).float()

        self.normalize_anchor = False
        self.anchors = self.anchor_generator.generate(
            input_size, normalize=self.normalize_anchor)

        self.use_focal_loss = model_config['use_focal_loss']

    def init_modules(self):

        in_channels = self.in_channels

        self.loc_feature1 = nn.Sequential(
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
        )

        self.loc_feature2 = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, kernel_size=1, stride=1),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
        )

        self.cls_feature1 = nn.Sequential(
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
        )

        self.cls_feature2 = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, kernel_size=1, stride=1),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
        )

        self.os_out = nn.Conv2d(in_channels=in_channels,
                                out_channels=self.num_anchors * 2,
                                kernel_size=1)
        self.cls_out = nn.Conv2d(in_channels,
                                 self.num_anchors * self.num_classes,
                                 kernel_size=1)
        self.box_out1 = nn.Conv2d(in_channels,
                                  out_channels=self.num_anchors *
                                  self.num_regress,
                                  kernel_size=1)
        self.box_out2 = nn.Conv2d(in_channels,
                                  out_channels=self.num_anchors *
                                  self.num_regress,
                                  kernel_size=1)

        self.feature_extractor = PRNetFeatureExtractor(
            self.feature_extractor_config)

        self.two_step_loss = TwoStepFocalLoss(self.num_classes)
        self.rpn_bbox_loss = nn.SmoothL1Loss(reduction='none')
        if self.use_focal_loss:
            # optimizes too slowly
            self.rpn_cls_loss = OldFocalLoss(self.num_classes)
            # fg or bg
            self.rpn_os_loss = OldFocalLoss(2)
        else:
            self.rpn_cls_loss = nn.CrossEntropyLoss(reduction='none')
            self.rpn_os_loss = nn.CrossEntropyLoss(reduction='none')

        self.retina_loss = RetinaNetLoss(self.num_classes)

    def forward(self, feed_dict):
        features = self.feature_extractor(feed_dict[constants.KEY_IMAGE])
        y_locs1 = []
        y_locs2 = []
        y_os = []
        y_cls = []

        for i, x in enumerate(features):
            # location out
            loc_feature = self.loc_feature1(x)
            loc1 = self.box_out1(loc_feature)

            N = loc1.size(0)
            loc1 = loc1.permute(0, 2, 3, 1).contiguous()
            loc1 = loc1.view(N, -1, self.num_regress)
            y_locs1.append(loc1)

            loc_feature = torch.cat([x, loc_feature], dim=1)
            loc_feature = self.loc_feature2(loc_feature)
            loc2 = self.box_out2(loc_feature)

            N = loc2.size(0)
            loc2 = loc2.permute(0, 2, 3, 1).contiguous()
            loc2 = loc2.view(N, -1, self.num_regress)
            loc2 += loc1
            y_locs2.append(loc2)

            # os out
            cls_feature = self.cls_feature1(x)
            os_out = self.os_out(cls_feature)
            os_out = os_out.permute(0, 2, 3, 1).contiguous()
            # _size = os_out.size(1)
            os_out = os_out.view(N, -1, 2)
            y_os.append(os_out)

            cls_feature = torch.cat([x, cls_feature], dim=1)
            cls_feature = self.cls_feature2(cls_feature)
            cls_out = self.cls_out(cls_feature)

            cls_out = cls_out.permute(0, 2, 3, 1).contiguous()
            cls_out = cls_out.view(N, -1, self.num_classes)
            y_cls.append(cls_out)

        loc1_preds = torch.cat(y_locs1, dim=1)
        loc2_preds = torch.cat(y_locs2, dim=1)
        os_preds = torch.cat(y_os, dim=1)
        cls_preds = torch.cat(y_cls, dim=1)

        batch_size = loc1_preds.shape[0]

        anchors = self.anchors.cuda()
        anchors = anchors.repeat(batch_size, 1, 1)

        coder = bbox_coders.build({'type': constants.KEY_BOXES_2D})
        proposals = coder.decode_batch(loc2_preds, anchors).detach()

        cls_probs = F.softmax(cls_preds.detach(), dim=-1)
        os_probs = F.softmax(os_preds.detach(), dim=-1)[:, :, 1:]
        os_probs[os_probs <= 0.4] = 0
        final_probs = cls_probs * os_probs
        image_info = feed_dict[constants.KEY_IMAGE_INFO].unsqueeze(
            -1).unsqueeze(-1)

        prediction_dict = {}
        if self.training:

            anchors_dict = {}
            anchors_dict[constants.KEY_PRIMARY] = anchors
            anchors_dict[constants.KEY_BOXES_2D] = loc1_preds
            anchors_dict[constants.KEY_BOXES_2D_REFINE] = loc2_preds
            anchors_dict[constants.KEY_CLASSES] = cls_preds
            anchors_dict[constants.KEY_OBJECTNESS] = os_preds
            # anchors_dict[constants.KEY_FINAL_PROBS] = final_probs

            gt_dict = {}
            gt_dict[constants.KEY_PRIMARY] = feed_dict[
                constants.KEY_LABEL_BOXES_2D]
            gt_dict[constants.KEY_CLASSES] = None
            gt_dict[constants.KEY_BOXES_2D] = None
            gt_dict[constants.KEY_OBJECTNESS] = None
            gt_dict[constants.KEY_BOXES_2D_REFINE] = None

            auxiliary_dict = {}
            label_boxes_2d = feed_dict[constants.KEY_LABEL_BOXES_2D]
            if self.normalize_anchor:
                label_boxes_2d[:, :, ::2] = (
                    label_boxes_2d[:, :, ::2] / image_info[:, 1])
                label_boxes_2d[:, :, 1::2] = (
                    label_boxes_2d[:, :, 1::2] / image_info[:, 0])
            auxiliary_dict[constants.KEY_BOXES_2D] = label_boxes_2d
            gt_labels = feed_dict[constants.KEY_LABEL_CLASSES]
            auxiliary_dict[constants.KEY_CLASSES] = gt_labels
            auxiliary_dict[constants.KEY_NUM_INSTANCES] = feed_dict[
                constants.KEY_NUM_INSTANCES]
            auxiliary_dict[constants.KEY_PROPOSALS] = anchors

            proposals_dict, targets, stats = self.target_generators.generate_targets(
                anchors_dict, gt_dict, auxiliary_dict, subsample=False)

            # recall
            anchors_dict[constants.KEY_PRIMARY] = proposals
            _, _, second_stage_stats = self.target_generators.generate_targets(
                anchors_dict, gt_dict, auxiliary_dict, subsample=False)

            # precision
            fg_probs, _ = final_probs[:, :, 1:].max(dim=-1)
            fake_match = auxiliary_dict[constants.KEY_FAKE_MATCH]
            second_stage_stats.update(
                Analyzer.analyze_precision(
                    fake_match,
                    fg_probs,
                    feed_dict[constants.KEY_NUM_INSTANCES],
                    thresh=0.3))

            prediction_dict[constants.KEY_STATS] = [stats, second_stage_stats]
            prediction_dict[constants.KEY_TARGETS] = targets
        else:

            prediction_dict[constants.KEY_CLASSES] = final_probs
            # prediction_dict[constants.KEY_OBJECTNESS] = os_preds

            proposals[:, :, ::2] = proposals[:, :, ::2] / image_info[:, 3]
            proposals[:, :, 1::2] = proposals[:, :, 1::2] / image_info[:, 2]
            prediction_dict[constants.KEY_BOXES_2D] = proposals
        return prediction_dict

    def loss(self, prediction_dict, feed_dict):
        loss_dict = {}

        targets = prediction_dict[constants.KEY_TARGETS]

        cls_target = targets[constants.KEY_CLASSES]
        loc1_target = targets[constants.KEY_BOXES_2D]
        loc2_target = targets[constants.KEY_BOXES_2D_REFINE]
        os_target = targets[constants.KEY_OBJECTNESS]

        loc1_preds = loc1_target['pred']
        loc2_preds = loc2_target['pred']
        loc1_target = loc1_target['target']
        loc2_target = loc2_target['target']
        assert loc1_target.shape == loc2_target.shape
        loc_target = loc1_target

        conf_preds = cls_target['pred']
        conf_target = cls_target['target']
        conf_weight = cls_target['weight']
        conf_target[conf_weight == 0] = -1

        os_preds = os_target['pred']
        os_target_ = os_target['target']
        os_weight = os_target['weight']
        os_target_[os_weight == 0] = -1

        loc_loss, os_loss, conf_loss = self.two_step_loss(loc1_preds,
                                                          loc2_preds,
                                                          loc_target,
                                                          conf_preds,
                                                          conf_target,
                                                          os_preds,
                                                          os_target_,
                                                          is_print=False)

        # loss

        # loss_dict['total_loss'] = total_loss
        loss_dict['loc_loss'] = loc_loss
        loss_dict['os_loss'] = os_loss
        loss_dict['conf_loss'] = conf_loss

        return loss_dict

    def loss_orig(self, prediction_dict, feed_dict):
        # loss for cls
        loss_dict = {}

        targets = prediction_dict[constants.KEY_TARGETS]
        cls_target = targets[constants.KEY_CLASSES]
        loc1_target = targets[constants.KEY_BOXES_2D]
        loc2_target = targets[constants.KEY_BOXES_2D_REFINE]
        os_target = targets[constants.KEY_OBJECTNESS]

        rpn_cls_loss = common_loss.calc_loss(focal_loss_alt(self.num_classes),
                                             cls_target,
                                             normalize=False)
        rpn_loc1_loss = common_loss.calc_loss(self.rpn_bbox_loss, loc1_target)
        rpn_os_loss = common_loss.calc_loss(focal_loss_alt(2),
                                            os_target,
                                            normalize=False)
        rpn_loc2_loss = common_loss.calc_loss(self.rpn_bbox_loss, loc2_target)

        cls_targets = cls_target['target']
        pos = cls_targets > 0  # [N,#anchors]
        num_pos = pos.data.long().sum().clamp(min=1).float()

        cls_loss = rpn_cls_loss / num_pos

        os_loss = rpn_os_loss / num_pos

        loss_dict.update({
            'rpn_cls_loss': cls_loss,
            'rpn_loc1_loss': rpn_loc1_loss * 0.35,
            'rpn_loc2_loss': rpn_loc2_loss * 0.5,
            'rpn_os_loss': os_loss * 10
        })

        return loss_dict

    def loss_retina(self, prediction_dict, feed_dict):
        loss_dict = {}

        targets = prediction_dict[constants.KEY_TARGETS]
        cls_target = targets[constants.KEY_CLASSES]
        loc1_target = targets[constants.KEY_BOXES_2D]
        loc2_target = targets[constants.KEY_BOXES_2D_REFINE]
        os_target = targets[constants.KEY_OBJECTNESS]

        conf_weight = cls_target['weight']
        conf_target = cls_target['target']
        conf_target[conf_weight == 0] = -1

        os_preds = os_target['pred']
        os_target_ = os_target['target']
        os_weight = os_target['weight']
        os_target_[os_weight == 0] = -1

        total_loss = self.retina_loss(loc1_target['pred'], loc2_target['pred'],
                                      loc1_target['target'],
                                      cls_target['pred'], conf_target,
                                      os_preds, os_target_)
        loss_dict['total_loss'] = total_loss
        return loss_dict
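
At inference time this model gates softmaxed class probabilities with a foreground objectness probability, zeroing low-objectness anchors (the os_probs <= 0.4 cut in forward). A standalone sketch of that scoring step, assuming only torch:

import torch
import torch.nn.functional as F

def gated_class_probs(cls_preds, os_preds, os_thresh=0.4):
    """Sketch of the objectness-gated scoring used in forward above.

    cls_preds: (N, num_anchors, num_classes) raw class logits
    os_preds:  (N, num_anchors, 2) raw bg/fg objectness logits
    """
    cls_probs = F.softmax(cls_preds.detach(), dim=-1)
    # keep only the foreground column, shape (N, num_anchors, 1)
    os_probs = F.softmax(os_preds.detach(), dim=-1)[:, :, 1:]
    # suppress anchors whose objectness is at or below the threshold
    os_probs = torch.where(os_probs > os_thresh, os_probs,
                           torch.zeros_like(os_probs))
    return cls_probs * os_probs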
Example #5
class TwoStageRetinaLayer(Model):
    def init_param(self, model_config):
        # including bg
        self.num_classes = len(model_config['classes']) + 1
        self.in_channels = model_config.get('in_channels', 128)
        self.num_regress = model_config.get('num_regress', 4)
        self.feature_extractor_config = model_config[
            'feature_extractor_config']

        self.target_generators = TargetGenerator(
            model_config['target_generator_config'])
        self.anchor_generator = anchor_generators.build(
            model_config['anchor_generator_config'])

        self.num_anchors = self.anchor_generator.num_anchors
        input_size = torch.tensor(model_config['input_size']).float()
        self.anchors = self.anchor_generator.generate(input_size)

        self.use_focal_loss = model_config['use_focal_loss']

    def init_modules(self):

        in_channels = self.in_channels

        self.loc_feature1 = nn.Sequential(
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
        )

        self.loc_feature2 = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, kernel_size=1, stride=1),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
        )

        self.cls_feature1 = nn.Sequential(
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
        )

        self.cls_feature2 = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, kernel_size=1, stride=1),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
        )

        self.corners_feature = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(inplace=True),
        )

        self.os_out = nn.Conv2d(in_channels=in_channels,
                                out_channels=self.num_anchors * 2,
                                kernel_size=1)
        self.cls_out = nn.Conv2d(in_channels,
                                 self.num_anchors * self.num_classes,
                                 kernel_size=1)
        self.box_out1 = nn.Conv2d(in_channels,
                                  out_channels=self.num_anchors *
                                  self.num_regress,
                                  kernel_size=1)
        self.box_out2 = nn.Conv2d(in_channels,
                                  out_channels=self.num_anchors *
                                  self.num_regress,
                                  kernel_size=1)
        self.corners_out = nn.Conv2d(in_channels,
                                     out_channels=self.num_anchors *
                                     (2 + 4 + 1),
                                     kernel_size=1)

        self.feature_extractor = PRNetFeatureExtractor(
            self.feature_extractor_config)

        self.two_step_loss = TwoStepFocalLoss(self.num_classes)
        self.rpn_bbox_loss = nn.SmoothL1Loss(reduction='none')
        if self.use_focal_loss:
            # optimizes too slowly
            self.rpn_cls_loss = OldFocalLoss(self.num_classes)
            # fg or bg
            self.rpn_os_loss = OldFocalLoss(2)
        else:
            self.rpn_cls_loss = nn.CrossEntropyLoss(reduction='none')
            self.rpn_os_loss = nn.CrossEntropyLoss(reduction='none')

        self.retina_loss = RetinaNetLoss(self.num_classes)
        self.l1_loss = nn.L1Loss(reduction='none')
        self.smooth_l1_loss = nn.SmoothL1Loss(reduction='none')

    def forward(self, feed_dict):
        features = self.feature_extractor(feed_dict[constants.KEY_IMAGE])
        y_locs1 = []
        y_locs2 = []
        y_os = []
        y_cls = []
        y_corners = []

        for i, x in enumerate(features):
            # location out
            loc_feature = self.loc_feature1(x)
            loc1 = self.box_out1(loc_feature)

            N = loc1.size(0)
            loc1 = loc1.permute(0, 2, 3, 1).contiguous()
            loc1 = loc1.view(N, -1, self.num_regress)
            y_locs1.append(loc1)

            loc_feature = torch.cat([x, loc_feature], dim=1)
            loc_feature = self.loc_feature2(loc_feature)
            loc2 = self.box_out2(loc_feature)

            N = loc2.size(0)
            loc2 = loc2.permute(0, 2, 3, 1).contiguous()
            loc2 = loc2.view(N, -1, self.num_regress)
            loc2 += loc1
            y_locs2.append(loc2)

            # os out
            cls_feature = self.cls_feature1(x)
            os_out = self.os_out(cls_feature)
            os_out = os_out.permute(0, 2, 3, 1).contiguous()
            # _size = os_out.size(1)
            os_out = os_out.view(N, -1, 2)
            y_os.append(os_out)

            cls_feature = torch.cat([x, cls_feature], dim=1)
            cls_feature = self.cls_feature2(cls_feature)
            cls_out = self.cls_out(cls_feature)

            cls_out = cls_out.permute(0, 2, 3, 1).contiguous()
            cls_out = cls_out.view(N, -1, self.num_classes)
            y_cls.append(cls_out)

            # dim out
            corners_feature = self.corners_feature(x)
            corners_out = self.corners_out(corners_feature)
            corners_out = corners_out.permute(0, 2, 3, 1).contiguous()
            corners_out = corners_out.view(N, -1, 7)
            y_corners.append(corners_out)

        loc1_preds = torch.cat(y_locs1, dim=1)
        loc2_preds = torch.cat(y_locs2, dim=1)
        os_preds = torch.cat(y_os, dim=1)
        cls_preds = torch.cat(y_cls, dim=1)
        corners_preds = torch.cat(y_corners, dim=1)

        batch_size = loc1_preds.shape[0]

        anchors = self.anchors.cuda()
        anchors = anchors.repeat(batch_size, 1, 1)

        coder = bbox_coders.build(
            self.target_generators.target_generator_config['coder_config'])
        proposals = coder.decode_batch(loc2_preds, anchors).detach()

        cls_probs = F.softmax(cls_preds.detach(), dim=-1)
        os_probs = F.softmax(os_preds.detach(), dim=-1)[:, :, 1:]
        os_probs[os_probs <= 0.4] = 0
        final_probs = cls_probs * os_probs

        coder = bbox_coders.build({'type': constants.KEY_CORNERS_3D_GRNET})

        prediction_dict = {}
        if self.training:

            anchors_dict = {}
            anchors_dict[constants.KEY_PRIMARY] = anchors
            anchors_dict[constants.KEY_BOXES_2D] = loc1_preds
            anchors_dict[constants.KEY_BOXES_2D_REFINE] = loc2_preds
            anchors_dict[constants.KEY_CLASSES] = cls_preds
            anchors_dict[constants.KEY_OBJECTNESS] = os_preds
            anchors_dict[constants.KEY_CORNERS_3D_GRNET] = corners_preds

            # anchors_dict[constants.KEY_FINAL_PROBS] = final_probs

            gt_dict = {}
            gt_dict[constants.KEY_PRIMARY] = feed_dict[
                constants.KEY_LABEL_BOXES_2D]
            gt_dict[constants.KEY_CLASSES] = None
            gt_dict[constants.KEY_BOXES_2D] = None
            gt_dict[constants.KEY_OBJECTNESS] = None
            gt_dict[constants.KEY_BOXES_2D_REFINE] = None
            gt_dict[constants.KEY_CORNERS_3D_GRNET] = None

            auxiliary_dict = {}
            auxiliary_dict[constants.KEY_BOXES_2D] = feed_dict[
                constants.KEY_LABEL_BOXES_2D]
            auxiliary_dict[constants.KEY_STEREO_CALIB_P2] = feed_dict[
                constants.KEY_STEREO_CALIB_P2]
            auxiliary_dict[constants.KEY_BOXES_3D] = feed_dict[
                constants.KEY_LABEL_BOXES_3D]
            gt_labels = feed_dict[constants.KEY_LABEL_CLASSES]
            auxiliary_dict[constants.KEY_CLASSES] = gt_labels
            auxiliary_dict[constants.KEY_NUM_INSTANCES] = feed_dict[
                constants.KEY_NUM_INSTANCES]
            auxiliary_dict[constants.KEY_PROPOSALS] = anchors
            auxiliary_dict[constants.KEY_MEAN_DIMS] = feed_dict[
                constants.KEY_MEAN_DIMS]
            auxiliary_dict[constants.KEY_IMAGE_INFO] = feed_dict[
                constants.KEY_IMAGE_INFO]

            proposals_dict, targets, stats = self.target_generators.generate_targets(
                anchors_dict, gt_dict, auxiliary_dict, subsample=False)

            # recall
            anchors_dict[constants.KEY_PRIMARY] = proposals
            _, _, second_stage_stats = self.target_generators.generate_targets(
                anchors_dict, gt_dict, auxiliary_dict, subsample=False)

            # precision
            fg_probs, _ = final_probs[:, :, 1:].max(dim=-1)
            fake_match = auxiliary_dict[constants.KEY_FAKE_MATCH]
            second_stage_stats.update(
                Analyzer.analyze_precision(
                    fake_match,
                    fg_probs,
                    feed_dict[constants.KEY_NUM_INSTANCES],
                    thresh=0.3))

            prediction_dict[constants.KEY_STATS] = [stats, second_stage_stats]
            prediction_dict[constants.KEY_TARGETS] = targets
            prediction_dict[constants.KEY_PROPOSALS] = anchors
        else:

            prediction_dict[constants.KEY_CLASSES] = final_probs
            # prediction_dict[constants.KEY_OBJECTNESS] = os_preds

            image_info = feed_dict[constants.KEY_IMAGE_INFO]
            scale_w = image_info[:, 3].unsqueeze(-1).unsqueeze(-1)
            scale_h = image_info[:, 2].unsqueeze(-1).unsqueeze(-1)
            proposals[:, :, ::2] = proposals[:, :, ::2] / scale_w
            proposals[:, :, 1::2] = proposals[:, :, 1::2] / scale_h
            prediction_dict[constants.KEY_BOXES_2D] = proposals
            prediction_dict['rcnn_3d'] = torch.ones_like(proposals)

            corners_preds = coder.decode_batch(
                corners_preds.detach(), proposals,
                feed_dict[constants.KEY_STEREO_CALIB_P2])
            prediction_dict[constants.KEY_CORNERS_2D] = corners_preds

        if self.training:
            loss_dict = self.loss(prediction_dict, feed_dict)
            return prediction_dict, loss_dict
        else:
            return prediction_dict

    def calc_local_corners(self, dims, ry):
        h = dims[:, 0]
        w = dims[:, 1]
        l = dims[:, 2]

        # rotation matrix around the camera y axis (yaw)
        zeros = torch.zeros_like(ry[:, 0])
        ones = torch.ones_like(ry[:, 0])
        cos = torch.cos(ry[:, 0])
        sin = torch.sin(ry[:, 0])

        rotation_matrix = torch.stack(
            [cos, zeros, sin, zeros, ones, zeros, -sin, zeros, cos],
            dim=-1).reshape(-1, 3, 3)

        x_corners = torch.stack(
            [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2],
            dim=0)
        y_corners = torch.stack([zeros, zeros, zeros, zeros, -h, -h, -h, -h],
                                dim=0)
        z_corners = torch.stack(
            [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2],
            dim=0)

        # shape(N, 3, 8)
        box_points_coords = torch.stack((x_corners, y_corners, z_corners),
                                        dim=0)
        # rotate and translate
        # shape(N, 3, 8)
        corners_3d = torch.bmm(rotation_matrix,
                               box_points_coords.permute(2, 0, 1))

        return corners_3d.permute(0, 2, 1)

    def decode_center_depth(self, dims_preds, final_boxes_2d_xywh, p2):
        f = p2[:, 0, 0]
        h_2d = final_boxes_2d_xywh[:, :, -1] + 1e-6
        h_3d = dims_preds[:, :, 0]
        depth_preds = f.unsqueeze(-1) * h_3d / h_2d
        return depth_preds.unsqueeze(-1)
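
    # The line above is the pinhole camera relation: an object of metric
    # height h_3d at depth z projects to h_2d = f * h_3d / z pixels, so
    # z = f * h_3d / h_2d. Illustrative numbers (not from this repo):
    # f = 720, h_3d = 1.5, h_2d = 108  ->  z = 720 * 1.5 / 108 = 10.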

    def loss(self, prediction_dict, feed_dict):
        loss_dict = {}

        targets = prediction_dict[constants.KEY_TARGETS]

        cls_target = targets[constants.KEY_CLASSES]
        loc1_target = targets[constants.KEY_BOXES_2D]
        loc2_target = targets[constants.KEY_BOXES_2D_REFINE]
        os_target = targets[constants.KEY_OBJECTNESS]
        corners_target = targets[constants.KEY_CORNERS_3D_GRNET]

        loc1_preds = loc1_target['pred']
        loc2_preds = loc2_target['pred']
        loc1_target = loc1_target['target']
        loc2_target = loc2_target['target']
        assert loc1_target.shape == loc2_target.shape
        loc_target = loc1_target

        conf_preds = cls_target['pred']
        conf_target = cls_target['target']
        conf_weight = cls_target['weight']
        conf_target[conf_weight == 0] = -1

        os_preds = os_target['pred']
        os_target_ = os_target['target']
        os_weight = os_target['weight']
        os_target_[os_weight == 0] = -1

        loc_loss, os_loss, conf_loss = self.two_step_loss(loc1_preds,
                                                          loc2_preds,
                                                          loc_target,
                                                          conf_preds,
                                                          conf_target,
                                                          os_preds,
                                                          os_target_,
                                                          is_print=False)

        # 3d loss
        preds = corners_target['pred']
        targets = corners_target['target']
        weights = corners_target['weight']
        proposals = prediction_dict[constants.KEY_PROPOSALS]
        p2 = feed_dict[constants.KEY_STEREO_CALIB_P2]
        image_info = feed_dict[constants.KEY_IMAGE_INFO]
        weights = weights.unsqueeze(-1)

        local_corners_gt = targets[:, :, :24]
        location_gt = targets[:, :, 24:27]
        dims_gt = targets[:, :, 27:]
        N, M = local_corners_gt.shape[:2]

        global_corners_gt = (local_corners_gt.view(N, M, 8, 3) +
                             location_gt.view(N, M, 1, 3)).view(N, M, -1)
        center_depth_gt = location_gt[:, :, 2:]

        mean_dims = torch.tensor([1.8, 1.8, 3.7]).type_as(preds)
        dims_preds = torch.exp(preds[:, :, :3]) * mean_dims
        dims_loss = self.l1_loss(dims_preds, dims_gt) * weights
        ry_preds = preds[:, :, 3:4]
        local_corners_preds = []
        # calc local corners preds
        for batch_ind in range(N):
            local_corners_preds.append(
                self.calc_local_corners(dims_preds[batch_ind].detach(),
                                        ry_preds[batch_ind]))
        local_corners_preds = torch.stack(local_corners_preds, dim=0)

        center_2d_deltas_preds = preds[:, :, 4:6]
        center_depth_preds = preds[:, :, 6:]
        # decode center_2d
        proposals_xywh = geometry_utils.torch_xyxy_to_xywh(proposals)
        center_depth_init = self.decode_center_depth(dims_preds,
                                                     proposals_xywh, p2)
        center_depth_preds = center_depth_init * center_depth_preds
        center_2d_preds = (center_2d_deltas_preds * proposals_xywh[:, :, 2:] +
                           proposals_xywh[:, :, :2])

        # use gt depth to calc the loss, to keep the gradient smooth
        location_preds = []
        for batch_ind in range(N):
            location_preds.append(
                geometry_utils.torch_points_2d_to_points_3d(
                    center_2d_preds[batch_ind], center_depth_preds[batch_ind],
                    p2[batch_ind]))
        location_preds = torch.stack(location_preds, dim=0)
        global_corners_preds = (location_preds.view(N, M, 1, 3) +
                                local_corners_preds.view(N, M, 8, 3)).view(
                                    N, M, -1)

        # corners depth loss and center depth loss
        corners_depth_preds = local_corners_preds.view(N, M, 8, 3)[..., -1]
        corners_depth_gt = local_corners_gt.view(N, M, 8, 3)[..., -1]

        center_depth_loss = self.l1_loss(center_depth_preds,
                                         center_depth_gt) * weights

        # location loss
        location_loss = self.l1_loss(location_preds, location_gt) * weights

        # global corners loss
        global_corners_loss = self.l1_loss(global_corners_preds,
                                           global_corners_gt) * weights

        # proj 2d loss
        corners_2d_preds = []
        corners_2d_gt = []
        for batch_ind in range(N):
            corners_2d_preds.append(
                geometry_utils.torch_points_3d_to_points_2d(
                    global_corners_preds[batch_ind].view(-1, 3),
                    p2[batch_ind]))
            corners_2d_gt.append(
                geometry_utils.torch_points_3d_to_points_2d(
                    global_corners_gt[batch_ind].view(-1, 3), p2[batch_ind]))

        corners_2d_preds = torch.stack(corners_2d_preds, dim=0).view(N, M, -1)
        corners_2d_gt = torch.stack(corners_2d_gt, dim=0).view(N, M, -1)

        # image filter
        zeros = torch.zeros_like(image_info[:, 0])
        image_shape = torch.stack(
            [zeros, zeros, image_info[:, 1], image_info[:, 0]], dim=-1)
        image_shape = image_shape.type_as(corners_2d_gt).view(-1, 4)
        image_filter = geometry_utils.torch_window_filter(
            corners_2d_gt.view(N, -1, 2), image_shape,
            deltas=200).float().view(N, M, -1)

        encoded_corners_2d_gt = corners_2d_gt.view(N, M, 8, 2)
        encoded_corners_2d_preds = corners_2d_preds.view(N, M, 8, 2)
        corners_2d_loss = self.l1_loss(encoded_corners_2d_preds.view(
            N, M, -1), encoded_corners_2d_gt.view(N, M, -1)) * weights
        corners_2d_loss = (corners_2d_loss.view(N, M, 8, 2) *
                           image_filter.unsqueeze(-1))
        corners_2d_loss = corners_2d_loss.view(N, M, -1)
        corners_depth_loss = self.l1_loss(
            corners_depth_preds, corners_depth_gt) * weights * image_filter

        pos = weights > 0  # (N, M, 1)
        num_pos = pos.data.long().sum().clamp(min=1).float()

        loss_dict['loc_loss'] = loc_loss
        loss_dict['os_loss'] = os_loss
        loss_dict['conf_loss'] = conf_loss
        # loss_dict['corners_2d_loss'] = corners_2d_loss.sum() / num_pos * 0.1
        loss_dict['dims_loss'] = dims_loss.sum() / num_pos * 10
        loss_dict['global_corners_loss'] = global_corners_loss.sum(
        ) / num_pos * 10
        loss_dict['location_loss'] = location_loss.sum() / num_pos * 10
        loss_dict['center_depth_loss'] = center_depth_loss.sum() / num_pos * 10

        return loss_dict

    def loss_orig(self, prediction_dict, feed_dict):
        # loss for cls
        loss_dict = {}

        targets = prediction_dict[constants.KEY_TARGETS]
        cls_target = targets[constants.KEY_CLASSES]
        loc1_target = targets[constants.KEY_BOXES_2D]
        loc2_target = targets[constants.KEY_BOXES_2D_REFINE]
        os_target = targets[constants.KEY_OBJECTNESS]
        dims_target = targets[constants.KEY_DIMS]
        orients_target = targets[constants.KEY_ORIENTS_V2]

        rpn_cls_loss = common_loss.calc_loss(focal_loss_alt(self.num_classes),
                                             cls_target,
                                             normalize=False)
        rpn_loc1_loss = common_loss.calc_loss(self.rpn_bbox_loss, loc1_target)
        rpn_os_loss = common_loss.calc_loss(focal_loss_alt(2),
                                            os_target,
                                            normalize=False)
        rpn_loc2_loss = common_loss.calc_loss(self.rpn_bbox_loss, loc2_target)

        rpn_dims_loss = common_loss.calc_loss(self.rpn_bbox_loss, dims_target)
        rpn_orients_loss = common_loss.calc_loss(self.rcnn_orient_loss,
                                                 orients_target)

        cls_targets = cls_target['target']
        pos = cls_targets > 0  # [N,#anchors]
        num_pos = pos.data.long().sum().clamp(min=1).float()

        cls_loss = rpn_cls_loss / num_pos

        os_loss = rpn_os_loss / num_pos

        loss_dict.update({
            'rpn_cls_loss': cls_loss,
            'rpn_loc1_loss': rpn_loc1_loss * 0.35,
            'rpn_loc2_loss': rpn_loc2_loss * 0.5,
            'rpn_os_loss': os_loss * 10,
            'rpn_dims_loss': rpn_dims_loss,
            'rpn_orients_loss': rpn_orients_loss
        })

        return loss_dict

    def loss_retina(self, prediction_dict, feed_dict):
        loss_dict = {}

        targets = prediction_dict[constants.KEY_TARGETS]
        cls_target = targets[constants.KEY_CLASSES]
        loc1_target = targets[constants.KEY_BOXES_2D]
        loc2_target = targets[constants.KEY_BOXES_2D_REFINE]
        os_target = targets[constants.KEY_OBJECTNESS]

        conf_weight = cls_target['weight']
        conf_target = cls_target['target']
        conf_target[conf_weight == 0] = -1

        os_preds = os_target['pred']
        os_target_ = os_target['target']
        os_weight = os_target['weight']
        os_target_[os_weight == 0] = -1

        total_loss = self.retina_loss(loc1_target['pred'], loc2_target['pred'],
                                      loc1_target['target'],
                                      cls_target['pred'], conf_target,
                                      os_preds, os_target_)
        loss_dict['total_loss'] = total_loss
        return loss_dict
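
The yaw rotation assembled row by row inside calc_local_corners is the standard rotation about the camera y axis. A minimal standalone check of that construction, assuming only torch:

import torch

def ry_rotation_matrix(ry):
    """Per-angle y-axis rotation, R = [[cos, 0, sin],
                                       [0,   1, 0  ],
                                       [-sin, 0, cos]]."""
    cos, sin = torch.cos(ry), torch.sin(ry)
    zeros, ones = torch.zeros_like(ry), torch.ones_like(ry)
    return torch.stack(
        [cos, zeros, sin, zeros, ones, zeros, -sin, zeros, cos],
        dim=-1).reshape(-1, 3, 3)

# A zero angle must give the identity matrix.
assert torch.allclose(ry_rotation_matrix(torch.zeros(1)), torch.eye(3))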
Example #6
class RPNModel(Model):
    def init_param(self, model_config):
        self.in_channels = model_config['din']
        self.post_nms_topN = model_config['post_nms_topN']
        self.pre_nms_topN = model_config['pre_nms_topN']
        self.nms_thresh = model_config['nms_thresh']
        self.use_focal_loss = model_config['use_focal_loss']

        # anchor generator
        self.anchor_generator = anchor_generators.build(
            model_config['anchor_generator_config'])
        self.num_anchors = self.anchor_generator.num_anchors
        self.nc_bbox_out = 4 * self.num_anchors
        self.nc_score_out = self.num_anchors * 2

        self.target_generators = TargetGenerator(
            model_config['target_generator_config'])

    def init_weights(self):
        self.truncated = False

        Filler.normal_init(self.rpn_conv, 0, 0.01, self.truncated)
        Filler.normal_init(self.rpn_cls_score, 0, 0.01, self.truncated)
        Filler.normal_init(self.rpn_bbox_pred, 0, 0.01, self.truncated)

    def init_modules(self):
        # define the convrelu layers processing input feature map
        self.rpn_conv = nn.Conv2d(self.in_channels, 512, 3, 1, 1, bias=True)

        # define bg/fg classification score layer
        self.rpn_cls_score = nn.Conv2d(512, self.nc_score_out, 1, 1, 0)

        # define anchor box offset prediction layer

        bbox_feat_channels = 512
        self.rpn_bbox_pred = nn.Conv2d(bbox_feat_channels, self.nc_bbox_out, 1,
                                       1, 0)

        # bbox
        self.rpn_bbox_loss = nn.SmoothL1Loss(reduction='none')

        # cls
        if self.use_focal_loss:
            self.rpn_cls_loss = FocalLoss(2, gamma=2, alpha=0.25)
        else:
            self.rpn_cls_loss = nn.CrossEntropyLoss(reduction='none')

    def generate_proposal(self, rpn_cls_probs, anchors, rpn_bbox_preds,
                          im_info):
        # TODO create a new Function
        """
        Args:
        rpn_cls_probs: FloatTensor,shape(N,2*num_anchors,H,W)
        rpn_bbox_preds: FloatTensor,shape(N,num_anchors*4,H,W)
        anchors: FloatTensor,shape(N,4,H,W)

        Returns:
        proposals_batch: FloatTensor, shape(N,post_nms_topN,4)
        fg_probs_batch: FloatTensor, shape(N,post_nms_topN)
        """
        # do not backward
        rpn_cls_probs = rpn_cls_probs.detach()
        rpn_bbox_preds = rpn_bbox_preds.detach()

        batch_size = rpn_bbox_preds.shape[0]
        rpn_bbox_preds = rpn_bbox_preds.permute(0, 2, 3, 1).contiguous()
        # shape(N,H*W*num_anchors,4)
        rpn_bbox_preds = rpn_bbox_preds.view(batch_size, -1, 4)

        coders = bbox_coders.build(
            self.target_generators.target_generator_config['coder_config'])
        proposals = coders.decode_batch(rpn_bbox_preds, anchors)

        # filter and clip
        proposals = box_ops.clip_boxes(proposals, im_info)

        # fg prob
        fg_probs = rpn_cls_probs[:, self.num_anchors:, :, :]
        fg_probs = fg_probs.permute(0, 2, 3,
                                    1).contiguous().view(batch_size, -1)

        # sort fg
        _, fg_probs_order = torch.sort(fg_probs, dim=1, descending=True)

        # fg_probs_batch = torch.zeros(batch_size,
        # self.post_nms_topN).type_as(rpn_cls_probs)
        proposals_batch = torch.zeros(batch_size, self.post_nms_topN,
                                      4).type_as(rpn_bbox_preds)
        proposals_order = torch.zeros(
            batch_size, self.post_nms_topN).fill_(-1).type_as(fg_probs_order)

        for i in range(batch_size):
            proposals_single = proposals[i]
            fg_probs_single = fg_probs[i]
            fg_order_single = fg_probs_order[i]
            # pre nms
            if self.pre_nms_topN > 0:
                fg_order_single = fg_order_single[:self.pre_nms_topN]
            proposals_single = proposals_single[fg_order_single]
            fg_probs_single = fg_probs_single[fg_order_single]

            # nms
            keep_idx_i = nms(proposals_single, fg_probs_single,
                             self.nms_thresh)
            keep_idx_i = keep_idx_i.long().view(-1)

            # post nms
            if self.post_nms_topN > 0:
                keep_idx_i = keep_idx_i[:self.post_nms_topN]
            proposals_single = proposals_single[keep_idx_i, :]
            fg_probs_single = fg_probs_single[keep_idx_i]
            fg_order_single = fg_order_single[keep_idx_i]

            # padding 0 at the end.
            num_proposal = keep_idx_i.numel()
            proposals_batch[i, :num_proposal, :] = proposals_single
            # fg_probs_batch[i, :num_proposal] = fg_probs_single
            proposals_order[i, :num_proposal] = fg_order_single
        return proposals_batch, proposals_order

    def forward(self, bottom_blobs):
        base_feat = bottom_blobs['base_feat']
        batch_size = base_feat.shape[0]
        im_info = bottom_blobs[constants.KEY_IMAGE_INFO]

        # rpn conv
        rpn_conv = F.relu(self.rpn_conv(base_feat), inplace=True)

        # rpn cls score
        # shape(N,2*num_anchors,H,W)
        rpn_cls_scores = self.rpn_cls_score(rpn_conv)

        # rpn cls prob shape(N,2*num_anchors,H,W)
        rpn_cls_score_reshape = rpn_cls_scores.view(batch_size, 2, -1)
        rpn_cls_probs = F.softmax(rpn_cls_score_reshape, dim=1)
        rpn_cls_probs = rpn_cls_probs.view_as(rpn_cls_scores)

        # rpn bbox pred
        # shape(N,4*num_anchors,H,W)
        rpn_bbox_preds = self.rpn_bbox_pred(rpn_conv)

        # generate anchors
        feature_map_list = [base_feat.size()[-2:]]
        anchors = self.anchor_generator.generate(feature_map_list,
                                                 im_info[0][:-1])

        anchors = anchors.unsqueeze(0).repeat(batch_size, 1, 1)

        ###############################
        # Proposal
        ###############################
        # note that proposals_order is used to track the transform of proposals
        proposals_batch, proposals_order = self.generate_proposal(
            rpn_cls_probs, anchors, rpn_bbox_preds, im_info)
        #  batch_idx = torch.arange(batch_size).view(batch_size, 1).expand(
        #  -1, proposals_batch.shape[1]).type_as(proposals_batch)
        #  rois_batch = torch.cat((batch_idx.unsqueeze(-1), proposals_batch),
        #  dim=2)

        # if self.training:
        # label_boxes_2d = bottom_blobs[constants.KEY_LABEL_BOXES_2D]
        # proposals_batch = self.append_gt(proposals_batch, label_boxes_2d)

        rpn_cls_scores = rpn_cls_scores.view(batch_size, 2, -1,
                                             rpn_cls_scores.shape[2],
                                             rpn_cls_scores.shape[3])
        rpn_cls_scores = rpn_cls_scores.permute(0, 3, 4, 2,
                                                1).contiguous().view(
                                                    batch_size, -1, 2)

        # postprocess
        rpn_cls_probs = rpn_cls_probs.view(batch_size, 2, -1,
                                           rpn_cls_probs.shape[2],
                                           rpn_cls_probs.shape[3])
        rpn_cls_probs = rpn_cls_probs.permute(0, 3, 4, 2, 1).contiguous().view(
            batch_size, -1, 2)

        rpn_bbox_preds = rpn_bbox_preds.permute(0, 2, 3, 1).contiguous()
        # shape(N,H*W*num_anchors,4)
        rpn_bbox_preds = rpn_bbox_preds.view(batch_size, -1, 4)

        predict_dict = {
            'proposals': proposals_batch,
            'rpn_cls_scores': rpn_cls_scores,
            #  'rois_batch': rois_batch,
            'anchors': anchors,

            # used for loss
            'rpn_bbox_preds': rpn_bbox_preds,
            'rpn_cls_probs': rpn_cls_probs,
            'proposals_order': proposals_order,
        }

        return predict_dict

    def append_gt(self, proposals_batch, label_boxes_2d):
        """
        Args:
            proposals_batch: shape(N, M, 4)
            label_boxes_2d: shape(N, m, 4)
            num_instances: shape(N,) valid num of bboxes in each image
        Returns:
            proposals_batch: shape(N, M+m, 4)
        """
        return torch.cat([proposals_batch, label_boxes_2d], dim=1)

    def loss(self, prediction_dict, feed_dict):
        # loss for cls
        loss_dict = {}
        anchors = prediction_dict['anchors']
        anchors_dict = {}
        anchors_dict[constants.KEY_PRIMARY] = anchors
        anchors_dict[
            constants.KEY_BOXES_2D] = prediction_dict['rpn_bbox_preds']
        anchors_dict[constants.KEY_CLASSES] = prediction_dict['rpn_cls_scores']

        gt_dict = {}
        gt_dict[constants.KEY_PRIMARY] = feed_dict[
            constants.KEY_LABEL_BOXES_2D]
        gt_dict[constants.KEY_CLASSES] = None
        gt_dict[constants.KEY_BOXES_2D] = None

        auxiliary_dict = {}
        auxiliary_dict[constants.KEY_BOXES_2D] = feed_dict[
            constants.KEY_LABEL_BOXES_2D]
        gt_labels = feed_dict[constants.KEY_LABEL_CLASSES]
        auxiliary_dict[constants.KEY_CLASSES] = torch.ones_like(gt_labels)
        auxiliary_dict[constants.KEY_NUM_INSTANCES] = feed_dict[
            constants.KEY_NUM_INSTANCES]
        auxiliary_dict[constants.KEY_PROPOSALS] = anchors

        _, targets, _ = self.target_generators.generate_targets(
            anchors_dict, gt_dict, auxiliary_dict, subsample=False)

        cls_target = targets[constants.KEY_CLASSES]
        reg_target = targets[constants.KEY_BOXES_2D]

        # loss

        if self.use_focal_loss:
            # when using focal loss, don't normalize it by all samples
            cls_targets = cls_target['target']
            pos = cls_targets > 0  # [N,#anchors]
            num_pos = pos.long().sum().clamp(min=1).float()
            rpn_cls_loss = common_loss.calc_loss(
                self.rpn_cls_loss, cls_target, normalize=False) / num_pos
        else:
            rpn_cls_loss = common_loss.calc_loss(self.rpn_cls_loss, cls_target)
        rpn_reg_loss = common_loss.calc_loss(self.rpn_bbox_loss, reg_target)
        loss_dict.update({
            'rpn_cls_loss': rpn_cls_loss,
            'rpn_reg_loss': rpn_reg_loss
        })

        return loss_dict
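
The per-image loop in generate_proposal above follows the usual RPN recipe: sort anchors by foreground score, keep the pre_nms_topN best, run NMS, then keep the post_nms_topN survivors and zero-pad. A condensed sketch of one image's path, using torchvision's nms as a stand-in for the repo's own nms import:

import torch
from torchvision.ops import nms  # stand-in for the nms used by RPNModel

def filter_proposals(boxes, scores, pre_topn=12000, post_topn=2000,
                     thresh=0.7):
    """boxes: (K, 4) decoded proposals, scores: (K,) fg probabilities."""
    order = scores.sort(descending=True).indices[:pre_topn]  # pre-NMS top-N
    boxes, scores = boxes[order], scores[order]
    keep = nms(boxes, scores, thresh)[:post_topn]            # NMS + post-NMS top-N
    return boxes[keep], scores[keep]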