예제 #1
0
    def init_param(self, model_config):
        """Copy hyper-parameters from ``model_config`` and build helpers.

        Args:
            model_config: Mapping indexed with ``[]``; every key read below
                must be present.
        """
        class_names = model_config['classes']
        self.classes = class_names
        self.n_classes = len(class_names)

        # Options stored on the instance under the same name as their key.
        # The last two are sub-module configs that are kept untouched.
        for option in ('class_agnostic', 'pooling_size', 'pooling_mode',
                       'crop_resize_with_max_pool', 'truncated',
                       'use_focal_loss', 'subsample_twice',
                       'rcnn_batch_size', 'feature_extractor_config',
                       'rpn_config'):
            setattr(self, option, model_config[option])

        # target assigner, then sampler (same construction order as before)
        assigner_cfg = model_config['target_assigner_config']
        self.target_assigner = TargetAssigner(assigner_cfg)
        sampler_cfg = model_config['sampler_config']
        self.sampler = BalancedSampler(sampler_cfg)
예제 #2
0
    def init_params(self, model_config):
        """Store the sub-module configs and build the target assigner.

        Args:
            model_config: Mapping providing ``feature_extractor_config``,
                ``multibox_cfg`` and ``target_assigner_config``.
        """
        # configs that are stored under the same name as their key
        for option in ('feature_extractor_config', 'multibox_cfg'):
            setattr(self, option, model_config[option])

        assigner_cfg = model_config['target_assigner_config']
        self.target_assigner = TargetAssigner(assigner_cfg)
예제 #3
0
    def init_param(self, model_config):
        """Read the RPN settings from ``model_config`` and build its helpers.

        Args:
            model_config: Mapping with the RPN options. ``use_iou`` is
                optional (read with ``.get``); all other keys are required.
        """
        self.in_channels = model_config['din']
        for option in ('post_nms_topN', 'pre_nms_topN', 'nms_thresh',
                       'use_score', 'rpn_batch_size', 'use_focal_loss'):
            setattr(self, option, model_config[option])

        # sampler
        sampler_cfg = model_config['sampler_config']
        self.sampler = DetectionSampler(sampler_cfg)

        # anchor generator and the output channel counts derived from it:
        # four box values and two scores per anchor
        anchor_cfg = model_config['anchor_generator_config']
        self.anchor_generator = AnchorGenerator(anchor_cfg)
        self.num_anchors = self.anchor_generator.num_anchors
        self.nc_bbox_out = 4 * self.num_anchors
        self.nc_score_out = self.num_anchors * 2

        # target assigner; its box coder is reused by this model
        assigner_cfg = model_config['target_assigner_config']
        self.target_assigner = TargetAssigner(assigner_cfg)
        self.bbox_coder = self.target_assigner.bbox_coder

        self.use_iou = model_config.get('use_iou')
예제 #4
0
    def __init__(self, layer_config):
        """Store the layer settings and create the assigner and sampler.

        Args:
            layer_config: Mapping indexed with ``[]``; every key read below
                must be present.
        """
        super().__init__()

        # Settings stored under the same name as their config key.
        # ``subsample_twice`` selects between sub-sampling on score and iou
        # or on score only (per the original author's note).
        for option in ('rpn_positive_weight', 'rpn_negative_overlaps',
                       'rpn_positive_overlaps', 'rpn_batch_size',
                       'subsample_twice', 'subsample_type'):
            setattr(self, option, layer_config[option])

        self.target_assigner = TargetAssigner()
        self.sampler = Sampler(self.subsample_type)
예제 #5
0
def build_target_assigner(target_assigner_config, bv_range, box_coder):
    """Build a ``TargetAssigner`` from its configuration object.

    Args:
        target_assigner_config: Config exposing the attributes
            ``ANCHOR_GENERATORS`` (iterable of anchor generator configs),
            ``REGION_SIMILARITY_CALCULATOR``, ``SAMPLE_POSITIVE_FRACTION``
            and ``SAMPLE_SIZE``.
        bv_range: Not used by this function; kept in the signature for
            callers that pass it.
        box_coder: Forwarded unchanged to ``TargetAssigner``.

    Returns:
        The constructed ``TargetAssigner``.
    """
    # one anchor generator per config entry, in config order
    anchor_generators = [
        build_anchor_generator(generator_cfg)
        for generator_cfg in target_assigner_config.ANCHOR_GENERATORS
    ]

    similarity_calc = build_similarity_calculator(
        target_assigner_config.REGION_SIMILARITY_CALCULATOR)

    # a negative fraction in the config is handed over as ``None``
    fraction = target_assigner_config.SAMPLE_POSITIVE_FRACTION
    return TargetAssigner(
        box_coder=box_coder,
        anchor_generators=anchor_generators,
        region_similarity_calculator=similarity_calc,
        positive_fraction=None if fraction < 0 else fraction,
        sample_size=target_assigner_config.SAMPLE_SIZE)
예제 #6
0
    def init_param(self, model_config):
        """Read the model settings and build the non-learnable helpers.

        Args:
            model_config: Mapping with the model options. ``nms_deltas`` is
                optional and falls back to ``1``; all other keys read below
                are required.
        """
        self.feat_size = model_config['common_feat_size']
        for option in ('batch_size', 'sample_size', 'pooling_size'):
            setattr(self, option, model_config[option])
        self.n_classes = model_config['num_classes']
        for option in ('use_focal_loss', 'feature_extractor_config'):
            setattr(self, option, model_config[option])

        # voxel grid, initialised right after construction
        voxel_cfg = model_config['voxel_generator_config']
        self.voxel_generator = VoxelGenerator(voxel_cfg)
        self.voxel_generator.init_voxels()

        self.integral_map_generator = IntegralMapGenerator()

        # assigner used for the training targets
        oft_assigner_cfg = model_config['target_assigner_config']
        self.oft_target_assigner = OFTargetAssigner(oft_assigner_cfg)

        # second assigner built from the eval config; its analyzer is told
        # not to append ground truth
        eval_assigner_cfg = model_config['eval_target_assigner_config']
        self.target_assigner = TargetAssigner(eval_assigner_cfg)
        self.target_assigner.analyzer.append_gt = False

        sampler_cfg = model_config['sampler_config']
        self.sampler = DetectionSampler(sampler_cfg)

        self.bbox_coder = self.oft_target_assigner.bbox_coder

        # used to find the most expensive operators
        self.profiler = Profiler()

        self.num_bins = model_config['num_bins']

        # 3 + 3 box values plus four values per angle bin
        self.reg_channels = 3 + 3 + self.num_bins * 4

        # score, pos, dim, ang
        self.rcnn_output_channels = self.n_classes + self.reg_channels
        self.rpn_output_channels = 2 + 3 + 3

        configured_deltas = model_config.get('nms_deltas')
        self.nms_deltas = 1 if configured_deltas is None else configured_deltas
예제 #7
0
    def init_param(self, model_config):
        """Read the model settings and build sampler, assigner and anchors.

        Args:
            model_config: Mapping indexed with ``[]``; every key read below
                must be present.
        """
        extractor_cfg = model_config['feature_extractor_config']
        self.feature_extractor_config = extractor_cfg

        # fixed multibox layout: the value 3 for each of six entries
        # (not read from the config)
        self.multibox_cfg = [3] * 6

        class_names = model_config['classes']
        self.n_classes = len(class_names)

        sampler_cfg = model_config['sampler_config']
        self.sampler = DetectionSampler(sampler_cfg)

        for option in ('batch_size', 'use_focal_loss'):
            setattr(self, option, model_config[option])

        assigner_cfg = model_config['target_assigner_config']
        self.target_assigner = TargetAssigner(assigner_cfg)

        anchor_cfg = model_config['anchor_generator_config']
        self.anchor_generator = AnchorGenerator(anchor_cfg)

        # the assigner's box coder is reused by this model
        self.bbox_coder = self.target_assigner.bbox_coder
예제 #8
0
class OFTModel(Model):
    def forward(self, feed_dict):
        """Run the network on one batch and return the 3D predictions.

        Reads ``feed_dict['p2']``, ``feed_dict['im_info']`` and
        ``feed_dict['img']``. Each stage is wrapped in a profiler section
        named '1' to '7'. The head output and the BEV feature map are also
        registered through ``self.add_feat``.

        Returns:
            dict with the keys ``pred_boxes_3d`` and ``pred_probs_3d``.
            While ``self.training`` is true ``pred_boxes_3d`` is the raw
            regression slice of the head output. Otherwise the boxes are
            decoded and the class-1 probabilities of sample 0 are smoothed
            and passed through ``nms_map``.
        """

        # stage 1: voxel projection; the result is kept inside
        # ``self.voxel_generator`` (nothing is returned here)
        self.profiler.start('1')
        self.voxel_generator.proj_voxels_3dTo2d(feed_dict['p2'],
                                                feed_dict['im_info'])
        self.profiler.end('1')

        # stage 2: image features
        self.profiler.start('2')
        img_feat_maps = self.feature_extractor.forward(feed_dict['img'])
        self.profiler.end('2')

        # stage 3: per-map channel reduction
        self.profiler.start('3')
        img_feat_maps = self.feature_preprocess(img_feat_maps)
        self.profiler.end('3')

        # stage 4: integral maps
        self.profiler.start('4')
        integral_maps = self.generate_integral_maps(img_feat_maps)
        self.profiler.end('4')

        # stage 5: OFT maps
        self.profiler.start('5')
        oft_maps = self.generate_oft_maps(integral_maps)
        self.profiler.end('5')

        # stage 6: BEV features
        self.profiler.start('6')
        bev_feat_maps = self.feature_extractor.bev_feature(oft_maps)
        self.profiler.end('6')

        # pred output
        # shape (NCHW)
        self.profiler.start('7')
        output_maps = self.output_head(bev_feat_maps)
        self.profiler.end('7')

        # shape(N,M,out_channels)
        pred_3d = output_maps.permute(0, 2, 3, 1).contiguous().view(
            self.batch_size, -1, self.output_channels)

        # the first ``n_classes`` channels are scores, the rest box values
        pred_boxes_3d = pred_3d[:, :, self.n_classes:]
        pred_scores_3d = pred_3d[:, :, :self.n_classes]

        pred_probs_3d = F.softmax(pred_scores_3d, dim=-1)
        # note: channel 1 of the raw head output is exported under the name
        # 'pred_scores_3d'
        self.add_feat('pred_scores_3d', output_maps[:, 1:2, :, :])
        self.add_feat('bev_feat_map', bev_feat_maps)

        if not self.training:
            # keep one centre per (.., D, 3) group: index 0 along D
            voxel_centers = self.voxel_generator.voxel_centers
            D = self.voxel_generator.lattice_dims[1]
            voxel_centers = voxel_centers.view(-1, D, 3)[:, 0, :]
            # decode angle (channels 6 and up); done before the boxes are
            # overwritten below
            angles_oritations = self.bbox_coder.decode_batch_angle_multibin(
                pred_boxes_3d[:, :, 6:], self.angle_loss.bin_centers,
                self.num_bins)

            # decode the first six box channels against the voxel centres
            pred_boxes_3d = self.bbox_coder.decode_batch_bbox(
                voxel_centers, pred_boxes_3d[:, :, :6])

            pred_boxes_3d = torch.cat([pred_boxes_3d, angles_oritations],
                                      dim=-1)

            # gaussian-filter the probability map
            # reshape first; NOTE(review): only sample 0 of the batch is
            # post-processed here and below
            shape = output_maps.shape[-2:]
            fg_mask = pred_probs_3d[0, :, 1].view(shape).detach().cpu().numpy()

            # then smooth; ``nms_deltas`` is used as the gaussian sigma
            from scipy.ndimage import gaussian_filter
            smoothed_fg_mask = gaussian_filter(fg_mask, sigma=self.nms_deltas)

            smoothed_fg_mask = torch.tensor(smoothed_fg_mask).type_as(
                pred_probs_3d)

            # nms
            smoothed_fg_mask = self.nms_map(smoothed_fg_mask)

            # assign back to tensor (in-place write into ``pred_probs_3d``)
            pred_probs_3d[0, :, 1] = smoothed_fg_mask.view(-1)

            # reset bg according to fg
            pred_probs_3d[0, :, 0] = 1 - pred_probs_3d[0, :, 1]

        prediction_dict = {}
        prediction_dict['pred_boxes_3d'] = pred_boxes_3d
        prediction_dict['pred_probs_3d'] = pred_probs_3d

        return prediction_dict

    def nms_map(self, smoothed_fg_mask):
        """Suppress every cell that is not a maximum of its 3x3 window.

        A cell is kept when its value is greater than or equal to each of
        its eight neighbours; positions outside the map count as ``0``.
        All other cells are set to ``0``.

        Args:
            smoothed_fg_mask: 2-D floating point score map. It is modified
                in place.

        Returns:
            The same tensor object that was passed in, after suppression.
        """
        # Zero-pad one cell on every side so border cells are compared
        # against 0, then take the maximum of every 3x3 window. This works
        # on whatever device the input lives on (no hard-coded ``.cuda()``).
        padded = F.pad(
            smoothed_fg_mask.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1),
            mode='constant',
            value=0)
        window_max = F.max_pool2d(padded, kernel_size=3, stride=1)[0, 0]

        # Build the suppression mask directly instead of inverting a keep
        # mask: ``~`` on a uint8 tensor is a bitwise NOT on current PyTorch,
        # which makes every entry non-zero and would wipe the whole map.
        suppressed = smoothed_fg_mask < window_max
        smoothed_fg_mask[suppressed] = 0
        return smoothed_fg_mask

    def feature_preprocess(self, feat_maps):
        """Apply the i-th entry of ``self.feats_reduces`` to the i-th map.

        Args:
            feat_maps: Iterable of feature maps.

        Returns:
            list with one reduced feature map per input, in input order.
        """
        return [
            self.feats_reduces[level](level_feat)
            for level, level_feat in enumerate(feat_maps)
        ]

    def generate_integral_maps(self, img_feat_maps):
        """Run ``self.integral_map_generator.generate`` on every feature map.

        Args:
            img_feat_maps: Iterable of image feature maps.

        Returns:
            list of generated maps, one per input and in input order.
        """
        return [
            self.integral_map_generator.generate(feat_map)
            for feat_map in img_feat_maps
        ]

    def generate_oft_maps(self, integral_maps):
        """Sample the integral maps at the voxel projections and fuse them.

        ``self.integral_map_generator.calc`` is evaluated for every integral
        map; the results at index 0, 1 and 2 are added together, regrouped
        by the depth dimension and collapsed with ``self.feat_collapse``.

        Args:
            integral_maps: Sequence of integral maps; at least three entries
                are required because exactly the first three are fused.

        Returns:
            Tensor viewed as ``(batch_size, feat_size, height_dim, -1)``.
        """
        voxel_gen = self.voxel_generator
        # shape(N,4) according to the original author's note
        proj_2d = voxel_gen.normalized_voxel_proj_2d

        per_scale_feat = [
            self.integral_map_generator.calc(scale_map, proj_2d)
            for scale_map in integral_maps
        ]

        # element-wise sum of the first three scales
        fused = per_scale_feat[0] + per_scale_feat[1] + per_scale_feat[2]

        lattice_dims = voxel_gen.lattice_dims
        height_dim = lattice_dims[0]
        depth_dim = lattice_dims[1]

        # split off the depth axis and move it to dim 1, which is the
        # channel axis that ``feat_collapse`` operates on
        grouped = fused.view(self.batch_size, self.feat_size, -1, depth_dim)
        depth_first = grouped.permute(0, 3, 1, 2).contiguous()

        collapsed = self.feat_collapse(depth_first)
        return collapsed.view(self.batch_size, self.feat_size, height_dim, -1)

    def init_param(self, model_config):
        """Read the model settings and build the non-learnable helpers.

        Args:
            model_config: Mapping with the model options. ``nms_deltas`` is
                optional and falls back to ``1``; all other keys read below
                are required.
        """
        self.feat_size = model_config['common_feat_size']
        for option in ('batch_size', 'sample_size'):
            setattr(self, option, model_config[option])
        self.n_classes = model_config['num_classes']
        for option in ('use_focal_loss', 'feature_extractor_config'):
            setattr(self, option, model_config[option])

        # voxel grid, initialised right after construction
        voxel_cfg = model_config['voxel_generator_config']
        self.voxel_generator = VoxelGenerator(voxel_cfg)
        self.voxel_generator.init_voxels()

        self.integral_map_generator = IntegralMapGenerator()

        # assigner used for the training targets
        oft_assigner_cfg = model_config['target_assigner_config']
        self.oft_target_assigner = OFTargetAssigner(oft_assigner_cfg)

        # second assigner built from the eval config; its analyzer is told
        # not to append ground truth
        eval_assigner_cfg = model_config['eval_target_assigner_config']
        self.target_assigner = TargetAssigner(eval_assigner_cfg)
        self.target_assigner.analyzer.append_gt = False

        sampler_cfg = model_config['sampler_config']
        self.sampler = DetectionSampler(sampler_cfg)

        self.bbox_coder = self.oft_target_assigner.bbox_coder

        # used to find the most expensive operators
        self.profiler = Profiler()

        self.num_bins = model_config['num_bins']

        # 3 + 3 box values plus four values per angle bin
        self.reg_channels = 3 + 3 + self.num_bins * 4

        # score, pos, dim, ang
        self.output_channels = self.n_classes + self.reg_channels

        configured_deltas = model_config.get('nms_deltas')
        self.nms_deltas = 1 if configured_deltas is None else configured_deltas

    def init_modules(self):
        """Create the feature extractor, the 1x1 conv layers and the losses.

        Modules are created in a fixed order: extractor, the three reduction
        convs, the collapse conv, the output head, then the loss objects.
        """
        self.feature_extractor = OFTNetFeatureExtractor(
            self.feature_extractor_config)

        # three 1x1 convs that map 128 / 256 / 512 input channels to the
        # common feature size
        self.feats_reduces = nn.ModuleList([
            nn.Conv2d(in_channels, self.feat_size, 1, 1, 0)
            for in_channels in (128, 256, 512)
        ])

        # 1x1 conv from 8 channels down to a single one
        self.feat_collapse = nn.Conv2d(8, 1, 1, 1, 0)

        # 1x1 prediction head on 256 * 4 input channels
        self.output_head = nn.Conv2d(256 * 4, self.output_channels, 1, 1, 0)

        # losses; both regression and confidence use an unreduced L1 loss
        self.reg_loss = nn.L1Loss(reduce=False)
        self.conf_loss = nn.L1Loss(reduce=False)
        self.angle_loss = MultiBinLoss(num_bins=self.num_bins)

    def init_weights(self):
        """Delegate weight initialisation to the feature extractor."""
        extractor = self.feature_extractor
        extractor.init_weights()

    def loss(self, prediction_dict, feed_dict):
        """Compute the training losses and update the evaluation statistics.

        Args:
            prediction_dict: Output of ``forward``; ``pred_probs_3d`` and
                ``pred_boxes_3d`` are read. The key ``rcnn_reg_weights`` is
                added to it as a side effect.
            feed_dict: Batch data; reads ``gt_boxes_3d``, ``gt_labels``,
                ``gt_boxes_ground_2d_rect``, ``gt_boxes`` and ``p2``.

        Returns:
            dict with the keys ``rpn_cls_loss``, ``dim_loss``, ``pos_loss``
            and ``angle_loss``.

        Besides the losses, the method decodes the predictions of sample 0,
        feeds them to ``self.target_assigner`` / its analyzer and updates
        ``self.target_assigner.stat`` with angle statistics.

        NOTE(review): ``pred_boxes_3d`` is decoded again below, so this
        presumably expects the undecoded (training-mode) output of
        ``forward`` - confirm.
        """
        self.profiler.start('8')
        gt_boxes_3d = feed_dict['gt_boxes_3d']
        gt_labels = feed_dict['gt_labels']
        gt_boxes_ground_2d_rect = feed_dict['gt_boxes_ground_2d_rect']

        voxels_ground_2d = self.voxel_generator.proj_voxels_to_ground()
        # keep one centre per (.., D, 3) group: index 0 along D
        voxel_centers = self.voxel_generator.voxel_centers
        D = self.voxel_generator.lattice_dims[1]
        voxel_centers = voxel_centers.view(-1, D, 3)[:, 0, :]

        cls_weights, reg_weights, cls_targets, reg_targets = self.oft_target_assigner.assign(
            voxels_ground_2d, gt_boxes_ground_2d_rect, voxel_centers,
            gt_boxes_3d, gt_labels)

        # No sub-sampling is applied: ``self.sampler`` is not used in this
        # method and the weights from the assigner are used as they are.

        # number of entries with a positive regression weight, used as the
        # normaliser of the regression losses
        num_reg_coeff = (reg_weights > 0).sum(dim=-1)
        # avoid dividing by zero.
        # NOTE(review): the truth test on a tensor only works when
        # ``num_reg_coeff`` holds a single element - confirm batch size 1.
        if num_reg_coeff == 0:
            num_reg_coeff = torch.ones([]).type_as(num_reg_coeff)

        # cls loss: weighted L1 between the last probability channel and the
        # class targets, averaged over the last dimension
        rpn_cls_probs = prediction_dict['pred_probs_3d'][:, :, -1]
        rpn_cls_loss = self.conf_loss(rpn_cls_probs, cls_targets)
        rpn_cls_loss = rpn_cls_loss.view_as(cls_weights)
        rpn_cls_loss = rpn_cls_loss * cls_weights
        rpn_cls_loss = rpn_cls_loss.mean(dim=-1)

        # bbox loss: first six prediction channels against all target
        # channels except the last one
        rpn_bbox_preds = prediction_dict['pred_boxes_3d']
        rpn_reg_loss = self.reg_loss(rpn_bbox_preds[:, :, :6],
                                     reg_targets[:, :, :-1])
        rpn_reg_loss = rpn_reg_loss * reg_weights.unsqueeze(-1)
        num_reg_coeff = num_reg_coeff.type_as(reg_weights)

        # angle_loss: remaining prediction channels against the last target
        # channel
        angle_loss, angle_tp_mask = self.angle_loss(rpn_bbox_preds[:, :, 6:],
                                                    reg_targets[:, :, -1:])
        rpn_angle_loss = angle_loss * reg_weights

        # split reg loss: channels 0-2 are reported as dim_loss, 3-5 as
        # pos_loss
        dim_loss = rpn_reg_loss[:, :, :3].sum(dim=-1).sum(
            dim=-1) / num_reg_coeff
        pos_loss = rpn_reg_loss[:, :,
                                3:6].sum(dim=-1).sum(dim=-1) / num_reg_coeff
        angle_loss = rpn_angle_loss.sum(dim=-1).sum(dim=-1) / num_reg_coeff

        prediction_dict['rcnn_reg_weights'] = reg_weights

        loss_dict = {}

        loss_dict['rpn_cls_loss'] = rpn_cls_loss
        # the bbox loss is reported split instead of fused
        loss_dict['dim_loss'] = dim_loss
        loss_dict['pos_loss'] = pos_loss
        loss_dict['angle_loss'] = angle_loss

        self.profiler.end('8')

        # ---- statistics only from here on (outside profiler section) ----
        # same voxel-centre selection as at the top of this method
        voxel_centers = self.voxel_generator.voxel_centers
        D = self.voxel_generator.lattice_dims[1]
        voxel_centers = voxel_centers.view(-1, D, 3)[:, 0, :]
        # decode bbox
        pred_boxes_3d = self.bbox_coder.decode_batch_bbox(
            voxel_centers, rpn_bbox_preds[:, :, :6])
        # decode angle
        angles_oritations = self.bbox_coder.decode_batch_angle_multibin(
            rpn_bbox_preds[:, :, 6:], self.angle_loss.bin_centers,
            self.num_bins)
        pred_boxes_3d = torch.cat([pred_boxes_3d, angles_oritations], dim=-1)

        # select the top n by probability; only sample 0 is kept afterwards
        order = torch.sort(rpn_cls_probs, descending=True)[1]
        topn = 1000
        order = order[:, :topn]
        rpn_cls_probs = rpn_cls_probs[0][order[0]].unsqueeze(0)
        pred_boxes_3d = pred_boxes_3d[0][order[0]].unsqueeze(0)

        # decoded box layout as consumed by the projector
        target = {
            'dimension': pred_boxes_3d[0, :, :3],
            'location': pred_boxes_3d[0, :, 3:6],
            'ry': pred_boxes_3d[0, :, 6]
        }

        boxes_2d = Projector.proj_box_3to2img(target, feed_dict['p2'])
        gt_boxes = feed_dict['gt_boxes']
        num_gt = gt_labels.numel()
        # called for its side effect on the analyzer; the return value is
        # not used
        self.target_assigner.assign(boxes_2d, gt_boxes, eval_thresh=0.7)

        fake_match = self.target_assigner.analyzer.match
        self.target_assigner.analyzer.analyze_ap(fake_match,
                                                 rpn_cls_probs,
                                                 num_gt,
                                                 thresh=0.1)

        # angle stats over the entries with a positive regression weight
        angle_tp_mask = angle_tp_mask[reg_weights > 0]
        angles_tp_num = angle_tp_mask.int().sum().item()
        angles_all_num = angle_tp_mask.numel()

        self.target_assigner.stat.update({
            'cls_orient_2s_all_num': angles_all_num,
            'cls_orient_2s_tp_num': angles_tp_num
        })

        return loss_dict
예제 #9
0
class GateRPNModel(Model):
    def init_param(self, model_config):
        """Read the gated-RPN settings and build its helper objects.

        Args:
            model_config: Mapping indexed with ``[]``; every key read below
                must be present.
        """
        self.in_channels = model_config['din']
        for option in ('post_nms_topN', 'pre_nms_topN', 'nms_thresh',
                       'use_score', 'rpn_batch_size', 'use_focal_loss',
                       'gate_thresh'):
            setattr(self, option, model_config[option])

        # sampler
        sampler_cfg = model_config['sampler_config']
        self.sampler = DetectionSampler(sampler_cfg)

        # anchor generator and the output channel counts derived from it:
        # four box values and two scores per anchor
        anchor_cfg = model_config['anchor_generator_config']
        self.anchor_generator = AnchorGenerator(anchor_cfg)
        self.num_anchors = self.anchor_generator.num_anchors
        self.nc_bbox_out = 4 * self.num_anchors
        self.nc_score_out = self.num_anchors * 2

        # target assigner; its box coder is reused by this model
        assigner_cfg = model_config['target_assigner_config']
        self.target_assigner = TargetAssigner(assigner_cfg)
        self.bbox_coder = self.target_assigner.bbox_coder

    def init_weights(self):
        """Initialise the three RPN conv layers with ``Filler.normal_init``.

        Sets ``self.truncated`` to ``False`` and passes the arguments
        ``0, 0.01, self.truncated`` for every layer.
        """
        self.truncated = False

        for layer_name in ('rpn_conv', 'rpn_cls_score', 'rpn_bbox_pred'):
            Filler.normal_init(
                getattr(self, layer_name), 0, 0.01, self.truncated)

    def init_modules(self):
        """Create the RPN conv layers and the loss functions.

        Reads ``in_channels``, ``nc_score_out``, ``nc_bbox_out``,
        ``num_anchors``, ``use_score`` and ``use_focal_loss`` from the
        instance. When ``use_score`` is set, ``self.nc_bbox_out`` is divided
        by ``self.num_anchors`` in place and the box head takes two extra
        input channels.
        """
        # conv layer processing the input feature map
        self.rpn_conv = nn.Conv2d(self.in_channels, 512, 3, 1, 1, bias=True)

        # bg/fg classification score layer
        self.rpn_cls_score = nn.Conv2d(512, self.nc_score_out, 1, 1, 0)

        # anchor box offset prediction layer
        if self.use_score:
            bbox_feat_channels = 512 + 2
            # Floor division keeps the channel count an int. Plain ``/`` is
            # true division on Python 3 and would hand a float to
            # ``nn.Conv2d``, which rejects it.
            self.nc_bbox_out //= self.num_anchors
        else:
            bbox_feat_channels = 512
        self.rpn_bbox_pred = nn.Conv2d(bbox_feat_channels, self.nc_bbox_out, 1,
                                       1, 0)

        # bbox loss, unreduced
        self.rpn_bbox_loss = nn.modules.loss.SmoothL1Loss(reduce=False)

        # cls loss: focal loss over two classes, or unreduced cross entropy
        if self.use_focal_loss:
            self.rpn_cls_loss = FocalLoss(2)
        else:
            self.rpn_cls_loss = functools.partial(F.cross_entropy,
                                                  reduce=False)

    def get_rpn_cls_probs(self, bbox_pred, anchors=None):
        """Score each prediction by the inverse length of its (dx, dy) offset.

        Note that all inputs have no gradients (per the original author).

        Args:
            bbox_pred: shape (N,M,4); only channels 0 and 1 are read.
            anchors: shape (M,4); accepted but not used.

        Returns:
            Tensor of shape (N,M) holding ``1 / (sqrt(dx^2 + dy^2) + 1e-5)``.
        """
        # small constant that keeps the division finite for a zero offset
        eps = 1e-5

        offset_x = bbox_pred[:, :, 0]
        offset_y = bbox_pred[:, :, 1]
        squared_length = offset_x * offset_x + offset_y * offset_y
        offset_length = torch.sqrt(squared_length)
        return 1.0 / (offset_length + eps)

    def generate_proposal(self, rpn_cls_probs, anchors, rpn_bbox_preds,
                          im_info):
        # TODO create a new Function
        """Decode, gate, sort and NMS-filter the RPN box predictions.

        Args:
        rpn_cls_probs: FloatTensor,shape(N,2*num_anchors,H,W); the channels
            from ``num_anchors`` on are used as the gate.
        rpn_bbox_preds: FloatTensor,shape(N,num_anchors*4,H,W)
        anchors: sequence whose first element holds the anchors of the
            (single) feature map
        im_info: passed to ``box_ops.clip_boxes``

        Returns:
        proposals_batch: shape(N,post_nms_topN,4), zero padded at the end
        proposals_order: shape(N,post_nms_topN), index of each kept proposal
            in the flattened prediction list, padded with -1
        fg_probs: shape(N,H*W*num_anchors), gated score of every prediction
        """
        # only one feature map is supported: take its anchors
        anchors = anchors[0]
        # do not backward
        anchors = anchors
        rpn_cls_probs = rpn_cls_probs.detach()
        rpn_bbox_preds = rpn_bbox_preds.detach()

        batch_size = rpn_bbox_preds.shape[0]
        rpn_bbox_preds = rpn_bbox_preds.permute(0, 2, 3, 1).contiguous()
        # shape(N,H*W*num_anchors,4)
        rpn_bbox_preds = rpn_bbox_preds.view(batch_size, -1, 4)

        # apply deltas to anchors to decode
        proposals = self.bbox_coder.decode_batch(rpn_bbox_preds, anchors)

        # filter and clip
        proposals = box_ops.clip_boxes(proposals, im_info)

        # fg prob: the score is the inverse offset length computed by
        # ``get_rpn_cls_probs``; it is zeroed wherever the gate
        # probability is below ``gate_thresh``
        gate = rpn_cls_probs[:, self.num_anchors:, :, :]
        gate = gate.permute(0, 2, 3, 1).contiguous().view(batch_size, -1)
        fg_probs = self.get_rpn_cls_probs(rpn_bbox_preds, anchors)
        fg_probs[gate < self.gate_thresh] = 0

        # sort fg
        _, fg_probs_order = torch.sort(fg_probs, dim=1, descending=True)

        # output buffers; NOTE(review): they are sized with
        # ``post_nms_topN``, so a value <= 0 (which skips the truncation
        # below) presumably cannot work - confirm
        proposals_batch = torch.zeros(batch_size, self.post_nms_topN,
                                      4).type_as(rpn_bbox_preds)
        proposals_order = torch.zeros(
            batch_size, self.post_nms_topN).fill_(-1).type_as(fg_probs_order)

        for i in range(batch_size):
            proposals_single = proposals[i]
            fg_probs_single = fg_probs[i]
            fg_order_single = fg_probs_order[i]
            # pre nms: keep the best ``pre_nms_topN`` when it is positive
            if self.pre_nms_topN > 0:
                fg_order_single = fg_order_single[:self.pre_nms_topN]
            proposals_single = proposals_single[fg_order_single]
            fg_probs_single = fg_probs_single[fg_order_single]

            # nms on rows of (box, score)
            keep_idx_i = nms(
                torch.cat((proposals_single, fg_probs_single.unsqueeze(1)), 1),
                self.nms_thresh)
            keep_idx_i = keep_idx_i.long().view(-1)

            # post nms
            if self.post_nms_topN > 0:
                keep_idx_i = keep_idx_i[:self.post_nms_topN]
            proposals_single = proposals_single[keep_idx_i, :]
            fg_probs_single = fg_probs_single[keep_idx_i]
            fg_order_single = fg_order_single[keep_idx_i]

            # padding 0 at the end (-1 for the order buffer).
            num_proposal = keep_idx_i.numel()
            proposals_batch[i, :num_proposal, :] = proposals_single
            proposals_order[i, :num_proposal] = fg_order_single
        return proposals_batch, proposals_order, fg_probs

    def forward(self, bottom_blobs):
        """Run the gated RPN head and generate proposals.

        Args:
            bottom_blobs: dict providing ``base_feat``, ``gt_boxes`` and
                ``im_info``.

        Returns:
            dict with the keys ``proposals_batch``, ``rpn_cls_scores``,
            ``rois_batch``, ``anchors``, ``rpn_bbox_preds``,
            ``proposals_order`` and ``fg_probs``. In training mode the
            ground-truth boxes are appended to ``rois_batch``.
        """
        base_feat = bottom_blobs['base_feat']
        batch_size = base_feat.shape[0]
        gt_boxes = bottom_blobs['gt_boxes']
        im_info = bottom_blobs['im_info']

        # rpn conv
        rpn_conv = F.relu(self.rpn_conv(base_feat), inplace=True)

        # rpn cls score
        # shape(N,2*num_anchors,H,W)
        rpn_cls_scores = self.rpn_cls_score(rpn_conv)

        # softmax over the two-way split of the scores; the result is
        # viewed back to shape(N,2*num_anchors,H,W) and used as the gate
        rpn_cls_score_reshape = rpn_cls_scores.view(batch_size, 2, -1)
        gate_probs = F.softmax(rpn_cls_score_reshape, dim=1)
        gate_probs = gate_probs.view_as(rpn_cls_scores)

        # rpn bbox pred
        # shape(N,4*num_anchors,H,W)
        if self.use_score:
            # NOTE(review): this branch looks like it cannot run as written.
            # The permute below yields a 3-D tensor, yet it is indexed with
            # four subscripts in the loop and ``.shape[3]`` is read further
            # down; the loop variable ``i`` is also unused. Confirm intent.
            rpn_cls_scores = rpn_cls_score_reshape.permute(0, 2, 1)
            rpn_bbox_preds = []
            for i in range(self.num_anchors):
                rpn_bbox_feat = torch.cat(
                    [rpn_conv, rpn_cls_scores[:, ::self.num_anchors, :, :]],
                    dim=1)
                rpn_bbox_preds.append(self.rpn_bbox_pred(rpn_bbox_feat))
            rpn_bbox_preds = torch.cat(rpn_bbox_preds, dim=1)
        else:
            # get rpn offsets to the anchor boxes
            rpn_bbox_preds = self.rpn_bbox_pred(rpn_conv)

        # generate anchors for the single feature map
        feature_map_list = [base_feat.size()[-2:]]
        anchors = self.anchor_generator.generate(feature_map_list)

        ###############################
        # Proposal
        ###############################
        # note that proposals_order is used to track the transform of
        # proposals
        proposals_batch, proposals_order, fg_probs = self.generate_proposal(
            gate_probs, anchors, rpn_bbox_preds, im_info)
        # prepend the batch index as column 0 of every roi
        batch_idx = torch.arange(batch_size).view(batch_size, 1).expand(
            -1, proposals_batch.shape[1]).type_as(proposals_batch)
        rois_batch = torch.cat((batch_idx.unsqueeze(-1), proposals_batch),
                               dim=2)

        if self.training:
            rois_batch = self.append_gt(rois_batch, gt_boxes)

        # rearrange the raw scores to shape(N, H*W*num_anchors, 2)
        rpn_cls_scores = rpn_cls_scores.view(batch_size, 2, -1,
                                             rpn_cls_scores.shape[2],
                                             rpn_cls_scores.shape[3])
        rpn_cls_scores = rpn_cls_scores.permute(0, 3, 4, 2,
                                                1).contiguous().view(
                                                    batch_size, -1, 2)

        predict_dict = {
            'proposals_batch': proposals_batch,
            'rpn_cls_scores': rpn_cls_scores,
            'rois_batch': rois_batch,
            'anchors': anchors,

            # used for loss
            'rpn_bbox_preds': rpn_bbox_preds,
            'proposals_order': proposals_order,
            'fg_probs': fg_probs,
        }

        return predict_dict

    def append_gt(self, rois_batch, gt_boxes):
        """Append the ground-truth boxes to the proposals as extra RoIs.

        Args:
            rois_batch: RoIs with 5 values on the last axis, the first being
                the image index inside the batch (the caller builds it by
                concatenating a batch-index column in front of the
                proposals).
            gt_boxes: ground-truth boxes, indexed as (image, box, values);
                only the first four values of the last axis are used.

        Returns:
            ``rois_batch`` with one extra RoI per ground-truth box
            concatenated along dim 1.
        """
        batch_size, num_gt = gt_boxes.shape[0], gt_boxes.shape[1]
        gt_rois = torch.zeros(batch_size, num_gt, 5).type_as(gt_boxes)
        # Column 0 must carry the image index just like the proposal RoIs do;
        # leaving it at zero attributes every gt box to image 0 as soon as
        # the batch holds more than one image.
        batch_idx = torch.arange(batch_size).type_as(gt_boxes)
        gt_rois[:, :, 0] = batch_idx.view(batch_size, 1)
        gt_rois[:, :, 1:5] = gt_boxes[:, :, :4]
        # cat gt_boxes to rois_batch
        return torch.cat([rois_batch, gt_rois], dim=1)

    def loss(self, prediction_dict, feed_dict):
        """Compute the RPN classification and box-regression losses.

        Args:
            prediction_dict: reads 'anchors' (a one-element list),
                'rpn_cls_scores' and 'rpn_bbox_preds'.
            feed_dict: reads 'gt_boxes'.

        Returns:
            dict with 'rpn_cls_loss' and 'rpn_bbox_loss'; each is the weighted
            loss summed over dim 1 and divided by the per-row sample count.
        """
        loss_dict = {}

        gt_boxes = feed_dict['gt_boxes']

        anchors = prediction_dict['anchors']

        assert len(anchors) == 1, 'just one feature maps is supported now'
        anchors = anchors[0]

        #################################
        # target assigner
        ################################
        # no gt labels needed here, it is just a binary classification problem
        rpn_cls_targets, rpn_reg_targets, \
            rpn_cls_weights, rpn_reg_weights = \
            self.target_assigner.assign(anchors, gt_boxes, gt_labels=None)

        ################################
        # subsample
        ################################
        pos_indicator = rpn_cls_targets > 0
        indicator = rpn_cls_weights > 0

        # The sampler ranks candidates by the assigned overlaps; the
        # alternative (ranking by predicted foreground probability) is kept
        # behind this switch.
        use_iou_for_criteron = True
        if use_iou_for_criteron:
            cls_criterion = self.target_assigner.matcher.assigned_overlaps_batch
        else:
            cls_criterion = prediction_dict['fg_probs']

        batch_sampled_mask = self.sampler.subsample_batch(
            self.rpn_batch_size,
            pos_indicator,
            criterion=cls_criterion,
            indicator=indicator)
        batch_sampled_mask = batch_sampled_mask.type_as(rpn_cls_weights)
        rpn_cls_weights = rpn_cls_weights * batch_sampled_mask
        rpn_reg_weights = rpn_reg_weights * batch_sampled_mask

        # Integer cast keeps the truncating count of the former byte cast but
        # stays on whatever device the weights live on (no hard-wired CUDA).
        # clamp(min=1) protects each row against division by zero; a scalar
        # `if count == 0` test is ambiguous once the batch has several rows.
        num_cls_coeff = rpn_cls_weights.long().sum(dim=1).clamp(min=1)
        num_reg_coeff = rpn_reg_weights.long().sum(dim=1).clamp(min=1)

        # cls loss
        rpn_cls_score = prediction_dict['rpn_cls_scores']
        rpn_cls_loss = self.rpn_cls_loss(rpn_cls_score.view(-1, 2),
                                         rpn_cls_targets.view(-1))
        rpn_cls_loss = rpn_cls_loss.view_as(rpn_cls_weights)
        rpn_cls_loss *= rpn_cls_weights
        rpn_cls_loss = rpn_cls_loss.sum(dim=1) / num_cls_coeff.float()

        # bbox loss
        rpn_bbox_preds = prediction_dict['rpn_bbox_preds']
        # move dim 1 last, then regroup the values in rows of 4
        rpn_bbox_preds = rpn_bbox_preds.permute(0, 2, 3, 1).contiguous()
        rpn_bbox_preds = rpn_bbox_preds.view(rpn_bbox_preds.shape[0], -1, 4)
        rpn_reg_loss = self.rpn_bbox_loss(rpn_bbox_preds, rpn_reg_targets)
        rpn_reg_loss *= rpn_reg_weights.unsqueeze(-1).expand(-1, -1, 4)
        rpn_reg_loss = rpn_reg_loss.view(
            rpn_reg_loss.shape[0], -1).sum(dim=1) / num_reg_coeff.float()

        loss_dict['rpn_cls_loss'] = rpn_cls_loss
        loss_dict['rpn_bbox_loss'] = rpn_reg_loss
        return loss_dict
예제 #10
0
class SINetModel(Model):
    def collect_intermedia_layers(self, img):
        """Run the first stage and keep its last two intermediate outputs.

        Returns:
            ``(feat3, end_points)``: ``end_points['feat2']`` is the output of
            every first-stage layer but the last, ``end_points['feat3']`` the
            output of the last layer applied on top of it.
        """
        stages = self.feature_extractor.first_stage_feature
        shallow = stages[:-1](img)
        deep = stages[-1](shallow)
        return deep, dict(feat2=shallow, feat3=deep)

    def caroi_pooling(self, all_feats, rois_batch, out_channels):
        """Pool every feature map over the same RoIs and fuse along channels.

        Args:
            all_feats: feature maps to pool, either an iterable of tensors or
                a name -> tensor mapping such as the ``end_points`` returned
                by ``collect_intermedia_layers``.
            rois_batch: RoIs, handed unchanged to ``self.rcnn_pooling``.
            out_channels: channel count expected from the fused result.

        Returns:
            The pooled maps concatenated on dim 1, passed through
            ``self.reduce_pooling`` when their channel count differs from
            ``out_channels``.
        """
        # Iterating a mapping yields its keys (strings), so pool its values.
        if isinstance(all_feats, dict):
            feats = all_feats.values()
        else:
            feats = all_feats
        # NOTE(review): every map is pooled by `self.rcnn_pooling`, while
        # `self.rcnn_pooling2` built in init_modules is never used -- confirm
        # whether one of the maps was meant to go through it.
        pooled_feats = torch.cat(
            [self.rcnn_pooling(feat, rois_batch) for feat in feats], dim=1)
        if pooled_feats.shape[1] != out_channels:
            # fuse the channels back down
            pooled_feats = self.reduce_pooling(pooled_feats)
        return pooled_feats

    def forward(self, feed_dict):
        """Run the backbone, the RPN and the second-stage heads.

        Args:
            feed_dict: reads 'img'; 'base_feat' is written into it before the
                RPN runs.

        Returns:
            A dict holding the RPN predictions plus 'rcnn_cls_probs',
            'rcnn_bbox_preds', 'rcnn_cls_scores' and 'second_rpn_anchors'.
        """
        # backbone: final first-stage output plus the intermediate maps
        base_feat, all_feats = self.collect_intermedia_layers(feed_dict['img'])
        feed_dict['base_feat'] = base_feat
        self.add_feat('base_feat', base_feat)

        # first stage
        prediction_dict = {}
        prediction_dict.update(self.rpn_model.forward(feed_dict))

        # subsample the RoIs up front while training to reduce memory use
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict)

        # RoIs are flattened to rows of 5 values before pooling
        flat_rois = prediction_dict['rois_batch'].view(-1, 5)
        roi_feat = self.caroi_pooling(all_feats, flat_rois, out_channels=1024)
        roi_feat = self.feature_extractor.second_stage_feature(roi_feat)
        if not self.reduce:
            roi_feat = roi_feat.view(self.rcnn_batch_size, -1)
        else:
            # average over dims 3 and 2
            roi_feat = roi_feat.mean(3).mean(2)

        rcnn_bbox_preds = self.rcnn_bbox_pred(roi_feat)
        rcnn_cls_scores = self.rcnn_cls_pred(roi_feat)

        prediction_dict['rcnn_cls_probs'] = F.softmax(rcnn_cls_scores, dim=1)
        prediction_dict['rcnn_bbox_preds'] = rcnn_bbox_preds
        prediction_dict['rcnn_cls_scores'] = rcnn_cls_scores

        # used for track
        order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = prediction_dict['anchors'][
            order]

        return prediction_dict

    def init_weights(self):
        """Initialise the sub-models, then the two prediction heads."""
        # submodule init weights
        for submodel in (self.feature_extractor, self.rpn_model):
            submodel.init_weights()

        # normal init, mean 0; the box head uses the smaller std
        heads = ((self.rcnn_cls_pred, 0.01), (self.rcnn_bbox_pred, 0.001))
        for layer, std in heads:
            Filler.normal_init(layer, 0, std, self.truncated)

    def init_modules(self):
        """Build the backbone, RPN, RoI pooling, prediction heads and losses.

        Raises:
            NotImplementedError: for the 'psalign' and 'deformable_psalign'
                pooling modes.
        """
        self.feature_extractor = ResNetFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)
        if self.pooling_mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(self.pooling_size,
                                            self.pooling_size, 1.0 / 16.0)
        elif self.pooling_mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)
        elif self.pooling_mode == 'psalign':
            raise NotImplementedError('have not implemented yet!')
        elif self.pooling_mode == 'deformable_psalign':
            raise NotImplementedError('have not implemented yet!')

        # forward() feeds the very same `pooled_feat` to both heads, so they
        # must agree on the input width: 2048 after spatial averaging, or the
        # flattened 2048 * 4 * 4 otherwise.
        if self.reduce:
            in_channels = 2048
        else:
            in_channels = 2048 * 4 * 4
        self.rcnn_cls_pred = nn.Linear(in_channels, self.n_classes)
        if self.class_agnostic:
            self.rcnn_bbox_pred = nn.Linear(in_channels, 4)
        else:
            self.rcnn_bbox_pred = nn.Linear(in_channels, 4 * self.n_classes)

        # loss module (both losses are left unreduced)
        if self.use_focal_loss:
            self.rcnn_cls_loss = FocalLoss(2)
        else:
            self.rcnn_cls_loss = functools.partial(
                F.cross_entropy, reduce=False)

        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

        # NOTE(review): `rcnn_pooling2` is built here but no method of this
        # class visible in this file uses it -- confirm intent.
        self.rcnn_pooling2 = RoIAlignAvg(self.pooling_size, self.pooling_size,
                                         1.0 / 8.0)
        # 1x1 conv used by caroi_pooling to fuse 1024 + 512 channels into 1024
        self.reduce_pooling = nn.Sequential(
            nn.Conv2d(1024 + 512, 1024, 1, 1, 0), nn.ReLU())

    def init_param(self, model_config):
        """Copy the model settings out of ``model_config``.

        Also builds the target assigner and the sampler from their
        sub-configs, and forces ``self.reduce`` to True.
        """
        class_names = model_config['classes']
        self.classes = class_names
        self.n_classes = len(class_names)

        # plain values and sub-module configs, stored under their config key
        copied_keys = (
            'class_agnostic',
            'pooling_size',
            'pooling_mode',
            'crop_resize_with_max_pool',
            'truncated',
            'use_focal_loss',
            'subsample_twice',
            'rcnn_batch_size',
            'feature_extractor_config',
            'rpn_config',
        )
        for key in copied_keys:
            setattr(self, key, model_config[key])

        # assigner
        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])

        # sampler
        self.sampler = BalancedSampler(model_config['sampler_config'])

        # not read from the config: always reduce
        self.reduce = True

    def pre_subsample(self, prediction_dict, feed_dict):
        """Assign targets to the RoIs and keep only a sampled subset.

        Writes 'rcnn_cls_weights', 'rcnn_reg_weights', 'rcnn_cls_targets',
        'rcnn_reg_targets' and 'fake_match' into ``prediction_dict`` and
        replaces 'rois_batch' with the sampled RoIs. Outside training mode
        'proposals_order' is filtered with the same mask.
        """
        all_rois = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        # assigner: the first value of every RoI is dropped before matching
        cls_targets, reg_targets, cls_weights, reg_weights = \
            self.target_assigner.assign(all_rois[:, :, 1:], gt_boxes,
                                        gt_labels)

        # subsampler: no ranking criterion, positives are the RoIs that
        # carry a regression weight
        keep = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            reg_weights > 0,
            indicator=cls_weights > 0,
            criterion=None)

        cls_weights = cls_weights[keep]
        reg_weights = reg_weights[keep]
        num_cls_coeff = (cls_weights > 0).sum(dim=-1)
        num_reg_coeff = (reg_weights > 0).sum(dim=-1)
        # check
        assert num_cls_coeff, 'bug happens'
        assert num_reg_coeff, 'bug happens'

        # weights are normalised by the number of active samples
        prediction_dict['rcnn_cls_weights'] = (
            cls_weights / num_cls_coeff.float())
        prediction_dict['rcnn_reg_weights'] = (
            reg_weights / num_reg_coeff.float())
        prediction_dict['rcnn_cls_targets'] = cls_targets[keep]
        prediction_dict['rcnn_reg_targets'] = reg_targets[keep]
        prediction_dict['fake_match'] = (
            self.target_assigner.analyzer.match[keep])

        # update rois_batch
        prediction_dict['rois_batch'] = all_rois[keep].view(
            all_rois.shape[0], -1, 5)

        if not self.training:
            # used for track
            order = prediction_dict['proposals_order']
            prediction_dict['proposals_order'] = order[keep]

    def loss(self, prediction_dict, feed_dict):
        """Compute the RPN losses plus the second-stage losses.

        The targets and (already normalised) weights are read from
        ``prediction_dict``. As a side effect the predicted probabilities of
        class 1 are handed to the assigner's analyzer.

        Returns:
            The RPN loss dict extended with 'rcnn_cls_loss' and
            'rcnn_bbox_loss'.
        """
        # submodule loss first
        losses = {}
        losses.update(self.rpn_model.loss(prediction_dict, feed_dict))

        # targets and weights
        cls_weights = prediction_dict['rcnn_cls_weights']
        reg_weights = prediction_dict['rcnn_reg_weights']
        cls_targets = prediction_dict['rcnn_cls_targets']
        reg_targets = prediction_dict['rcnn_reg_targets']

        # classification loss: weight every sample, then add up
        cls_loss = self.rcnn_cls_loss(prediction_dict['rcnn_cls_scores'],
                                      cls_targets)
        cls_loss *= cls_weights
        losses['rcnn_cls_loss'] = cls_loss.sum(dim=-1)

        # bounding box regression loss: sum over the last axis, weight every
        # sample, then add up
        bbox_loss = self.rcnn_bbox_loss(prediction_dict['rcnn_bbox_preds'],
                                        reg_targets)
        bbox_loss = bbox_loss.sum(dim=-1)
        bbox_loss *= reg_weights
        losses['rcnn_bbox_loss'] = bbox_loss.sum(dim=-1)

        # analysis ap
        cls_probs = prediction_dict['rcnn_cls_probs']
        gt_count = feed_dict['gt_labels'].numel()
        matches = prediction_dict['fake_match']
        self.target_assigner.analyzer.analyze_ap(
            matches, cls_probs[:, 1], gt_count, thresh=0.5)

        return losses
예제 #11
0
class AnchorTargetLayer(nn.Module):
    """Assign targets to anchors and subsample them for RPN training."""

    def __init__(self, layer_config):
        """Read the layer settings and build the assigner and the sampler."""
        super().__init__()
        # plain settings, stored under their config key;
        # `subsample_twice` selects between sampling cls and reg weights
        # independently or sharing a single mask
        for key in ('rpn_positive_weight', 'rpn_negative_overlaps',
                    'rpn_positive_overlaps', 'rpn_batch_size',
                    'subsample_twice', 'subsample_type'):
            setattr(self, key, layer_config[key])

        self.target_assigner = TargetAssigner()
        self.sampler = Sampler(self.subsample_type)

    def forward(self, anchors, rpn_cls_score, gt_boxes, gt_labels):
        """Assign targets to the anchors, then subsample the weights.

        Args:
            anchors: anchors handed to the target assigner.
            rpn_cls_score: scores given to the sampler as its criterion.
            gt_boxes: ground-truth boxes handed to the target assigner.
            gt_labels: ground-truth labels handed to the target assigner.

        Returns:
            dict with 'cls_targets', 'reg_targets', 'cls_weights' and
            'reg_weights'; the weights are masked by the sampler output.
        """
        twice = self.subsample_twice

        # assignments
        cls_targets, reg_targets, cls_weights, reg_weights = \
            self.target_assigner.assign(anchors, gt_boxes, gt_labels)

        # the score-driven mask is needed whichever mode is active
        score_mask = self.sampler.subsample(
            cls_weights,
            self.rpn_batch_size,
            cls_targets.type(torch.ByteTensor),
            critation=rpn_cls_score)

        if twice:
            # subsample both: the regression weights get their own mask, and
            # both weight tensors are updated in place
            cls_weights *= score_mask
            reg_mask = self.sampler.subsample(reg_weights,
                                              self.rpn_batch_size)
            reg_weights *= reg_mask
        else:
            # subsample score only: one mask shared by both weights
            cls_weights = cls_weights * score_mask
            reg_weights = reg_weights * score_mask

        return {
            'cls_targets': cls_targets,
            'reg_targets': reg_targets,
            'cls_weights': cls_weights,
            'reg_weights': reg_weights,
        }
예제 #12
0
class SSDModel(Model):
    """SSD-style detector: one location head and one confidence head per map.

    NOTE(review): the hook names used here (``init_params``, ``init_module``)
    differ from the ``init_param`` / ``init_modules`` used by the other models
    in this file -- confirm which names the ``Model`` base class calls.
    """

    def init_params(self, model_config):
        """Read the sub-configs and build the target assigner."""
        self.feature_extractor_config = model_config[
            'feature_extractor_config']
        self.multibox_cfg = model_config['multibox_cfg']

        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])

    def init_module(self):
        """Build the feature extractor, the multibox heads and the losses."""
        self.feature_extractor = PyramidVggnetExtractor(
            self.feature_extractor_config)

        # loc layers and conf layers
        base_feat = self.feature_extractor.base_feat
        extra_layers = self.feature_extractor.extras_layers
        loc_layers, conf_layers = self.make_multibox(base_feat, extra_layers)
        # Plain Python lists are invisible to nn.Module, so the head convs
        # would be missing from parameters(), state_dict() and .to(device);
        # nn.ModuleList registers them and still iterates like a list.
        self.loc_layers = nn.ModuleList(loc_layers)
        self.conf_layers = nn.ModuleList(conf_layers)

        # loss layers
        # Keep the function itself: calling it here without arguments raises
        # TypeError while the model is being built.
        self.loc_loss = F.smooth_l1_loss
        self.conf_loss = nn.CrossEntropyLoss(reduce=False)

    def make_multibox(self, vgg, extra_layers):
        """Create the 3x3 conv heads for every source feature map.

        Args:
            vgg: indexable collection of backbone layers; entries 21 and -2
                are used as sources and must expose ``out_channels``.
            extra_layers: extra layers; every second one, starting at index
                1, is used as a source.

        Returns:
            ``(loc_layers, conf_layers)``: two lists of ``nn.Conv2d``, in
            source order.
        """
        cfg = self.multibox_cfg
        # NOTE(review): `n_classes` is not set by `init_params` in this class
        # -- presumably provided elsewhere; confirm.
        num_classes = self.n_classes
        loc_layers = []
        conf_layers = []

        # (input channels, multibox_cfg entry) for every source map: the two
        # backbone layers first, then the extras counted from index 2
        sources = [(vgg[v].out_channels, cfg[k])
                   for k, v in enumerate([21, -2])]
        sources += [(v.out_channels, cfg[k])
                    for k, v in enumerate(extra_layers[1::2], 2)]

        for in_channels, num_boxes in sources:
            loc_layers.append(
                nn.Conv2d(in_channels, num_boxes * 4, kernel_size=3,
                          padding=1))
            conf_layers.append(
                nn.Conv2d(in_channels, num_boxes * num_classes,
                          kernel_size=3, padding=1))
        return loc_layers, conf_layers

    def init_weights(self):
        """No explicit weight initialisation is done by this model."""
        pass

    def forward(self, feed_dict):
        """Apply the multibox heads to every source feature map.

        Args:
            feed_dict: reads 'img'.

        Returns:
            dict with 'loc_preds' and 'conf_preds'; each concatenates the
            per-map outputs after flattening them to (batch, -1).
        """
        img = feed_dict['img']
        source_feats = self.feature_extractor(img)
        loc_preds = []
        conf_preds = []

        # apply multibox head to source layers; dim 1 of each output is
        # moved last before flattening
        for (x, l, c) in zip(source_feats, self.loc_layers, self.conf_layers):
            loc_preds.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf_preds.append(c(x).permute(0, 2, 3, 1).contiguous())

        loc_preds = torch.cat([o.view(o.size(0), -1) for o in loc_preds], 1)
        conf_preds = torch.cat([o.view(o.size(0), -1) for o in conf_preds], 1)
        prediction_dict = {'loc_preds': loc_preds, 'conf_preds': conf_preds}
        return prediction_dict

    def loss(self, prediction_dict, feed_dict):
        """Compute the location and confidence losses.

        NOTE(review): this method looks unfinished -- the assigner is called
        without arguments, `self.sampler` is not created by this class, and
        the sampled mask is never applied. Left as is pending confirmation.
        """
        loc_targets, conf_targets = self.target_assigner.assign()

        # ohem
        batch_sampled_mask = self.sampler.subsample()

        loc_preds = prediction_dict['loc_preds']
        # loc loss
        loc_loss = self.loc_loss(loc_preds, loc_targets)

        conf_preds = prediction_dict['conf_preds']
        # conf loss
        conf_loss = self.conf_loss(conf_preds, conf_targets)

        loss_dict = {'loc_loss': loc_loss, 'conf_loss': conf_loss}
        return loss_dict
class DoubleIoUSecondStageFasterRCNN(Model):
    def forward(self, feed_dict):
        """Run the backbone, the RPN and the decoupled second-stage heads.

        Args:
            feed_dict: reads 'img'; 'base_feat' is written into it before the
                RPN runs.

        Returns:
            A dict holding the RPN predictions plus 'rcnn_cls_probs',
            'rcnn_bbox_preds', 'rcnn_cls_scores' and 'second_rpn_anchors'.
        """
        extractor = self.feature_extractor

        # base model
        base_feat = extractor.first_stage_feature(feed_dict['img'])
        feed_dict['base_feat'] = base_feat
        self.add_feat('base_feat', base_feat)

        # rpn model
        prediction_dict = {}
        prediction_dict.update(self.rpn_model.forward(feed_dict))

        # pre subsample while training to reduce memory use
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict)

        # RoIs are flattened to rows of 5 values before pooling
        flat_rois = prediction_dict['rois_batch'].view(-1, 5)
        shared_feat = self.rcnn_pooling(base_feat, flat_rois)
        shared_feat = F.relu(self.rcnn_conv(shared_feat), inplace=True)

        # decouple the two branches; the classification branch works on a
        # detached copy, so its gradients stop at the shared feature
        cls_feat = self.rcnn_pooled_feat_cls(shared_feat.detach())
        reg_feat = self.rcnn_pooled_feat_bbox(shared_feat)

        # classification
        cls_feat = extractor.third_stage_feature(cls_feat)
        rcnn_cls_scores = self.rcnn_cls_pred(cls_feat.mean(3).mean(2))
        rcnn_cls_probs = F.softmax(rcnn_cls_scores, dim=1)

        # regression
        reg_feat = extractor.second_stage_feature(reg_feat)
        rcnn_bbox_preds = self.rcnn_bbox_pred(reg_feat.mean(3).mean(2))

        prediction_dict['rcnn_cls_probs'] = rcnn_cls_probs
        prediction_dict['rcnn_bbox_preds'] = rcnn_bbox_preds
        prediction_dict['rcnn_cls_scores'] = rcnn_cls_scores

        # used for track
        order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = prediction_dict['anchors'][
            order]

        return prediction_dict

    def unfreeze_part_modules(self, model):
        """Mark every parameter of ``model`` as requiring gradients."""
        for weight in model.parameters():
            weight.requires_grad = True

    def init_weights(self):
        """Initialise the sub-models, then the two prediction heads."""
        # submodule init weights
        for submodel in (self.feature_extractor, self.rpn_model):
            submodel.init_weights()

        # normal init, mean 0; the box head uses the smaller std
        heads = ((self.rcnn_cls_pred, 0.01), (self.rcnn_bbox_pred, 0.001))
        for layer, std in heads:
            Filler.normal_init(layer, 0, std, self.truncated)

    def init_modules(self):
        """Build the backbone, RPN, RoI pooling, prediction heads and losses.

        Raises:
            NotImplementedError: for the 'psalign' and 'deformable_psalign'
                pooling modes.
        """
        self.feature_extractor = ResNetFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)

        # RoI pooling
        mode = self.pooling_mode
        if mode in ('psalign', 'deformable_psalign'):
            raise NotImplementedError('have not implemented yet!')
        if mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(self.pooling_size,
                                            self.pooling_size, 1.0 / 16.0)
        elif mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)

        # prediction heads
        self.rcnn_cls_pred = nn.Linear(2048, self.n_classes)
        in_channels = 2048 if self.reduce else 2048 * 4 * 4
        # one box for all classes, or one box per class
        bbox_dims = 4 if self.class_agnostic else 4 * self.n_classes
        self.rcnn_bbox_pred = nn.Linear(in_channels, bbox_dims)

        # loss modules (both are left unreduced)
        if self.use_focal_loss:
            self.rcnn_cls_loss = FocalLoss(2)
        else:
            self.rcnn_cls_loss = functools.partial(F.cross_entropy,
                                                   reduce=False)
        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

        # decouple cls and bbox: a shared 3x3 conv followed by one 1x1 conv
        # per branch
        self.rcnn_conv = nn.Conv2d(1024, 512, 3, 1, 1, bias=True)
        self.rcnn_pooled_feat_cls = nn.Conv2d(512, 1024, 1, 1, 0)
        self.rcnn_pooled_feat_bbox = nn.Conv2d(512, 1024, 1, 1, 0)

    def init_param(self, model_config):
        """Copy the model settings out of ``model_config``.

        Also builds the target assigner, the sampler and the feature
        visualizer, and forces ``self.reduce`` to True.
        """
        class_names = model_config['classes']
        self.classes = class_names
        self.n_classes = len(class_names)

        # plain values and sub-module configs, stored under their config key
        copied_keys = (
            'class_agnostic',
            'pooling_size',
            'pooling_mode',
            'crop_resize_with_max_pool',
            'truncated',
            'use_focal_loss',
            'subsample_twice',
            'rcnn_batch_size',
            'feature_extractor_config',
            'rpn_config',
        )
        for key in copied_keys:
            setattr(self, key, model_config[key])

        # assigner
        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])

        # sampler
        self.sampler = BalancedSampler(model_config['sampler_config'])

        # not read from the config: always reduce
        self.reduce = True
        self.visualizer = FeatVisualizer()

    def pre_subsample(self, prediction_dict, feed_dict):
        """Assign targets to the RoIs and keep only a sampled subset.

        Writes 'rcnn_cls_weights', 'rcnn_reg_weights', 'rcnn_cls_targets',
        'fake_match' and 'rcnn_reg_targets' into ``prediction_dict`` and
        replaces 'rois_batch' with the sampled RoIs. Outside training mode
        'proposals_order' is filtered with the same mask.
        """
        all_rois = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        # assigner: the first value of every RoI is dropped before matching
        cls_targets, reg_targets, cls_weights, reg_weights = \
            self.target_assigner.assign(all_rois[:, :, 1:], gt_boxes,
                                        gt_labels)

        # subsampler: no ranking criterion, positives are the RoIs that
        # carry a regression weight
        keep = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            reg_weights > 0,
            indicator=cls_weights > 0,
            criterion=None)

        cls_weights = cls_weights[keep]
        reg_weights = reg_weights[keep]
        num_cls_coeff = (cls_weights > 0).sum(dim=-1)
        num_reg_coeff = (reg_weights > 0).sum(dim=-1)
        # check
        assert num_cls_coeff, 'bug happens'
        assert num_reg_coeff, 'bug happens'

        # weights are normalised by the number of active samples
        prediction_dict['rcnn_cls_weights'] = (
            cls_weights / num_cls_coeff.float())
        prediction_dict['rcnn_reg_weights'] = (
            reg_weights / num_reg_coeff.float())
        prediction_dict['rcnn_cls_targets'] = cls_targets[keep]
        prediction_dict['fake_match'] = (
            self.target_assigner.analyzer.match[keep])
        prediction_dict['rcnn_reg_targets'] = reg_targets[keep]

        # update rois_batch
        prediction_dict['rois_batch'] = all_rois[keep].view(
            all_rois.shape[0], -1, 5)

        if not self.training:
            # used for track
            order = prediction_dict['proposals_order']
            prediction_dict['proposals_order'] = order[keep]

    def loss(self, prediction_dict, feed_dict):
        """Compute the RPN losses plus the second-stage losses.

        The targets and (already normalised) weights are read from
        ``prediction_dict``. As a side effect the predicted probabilities of
        class 1 are handed to the assigner's analyzer.

        Args:
            prediction_dict: reads the 'rcnn_*' weights, targets, scores,
                box predictions and probabilities, plus 'fake_match'.
            feed_dict: reads 'gt_labels' (only its element count is used).

        Returns:
            The RPN loss dict extended with 'rcnn_cls_loss' and
            'rcnn_bbox_loss'.
        """
        loss_dict = {}

        # submodule loss
        loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

        # targets and weights
        rcnn_cls_weights = prediction_dict['rcnn_cls_weights']
        rcnn_reg_weights = prediction_dict['rcnn_reg_weights']

        rcnn_cls_targets = prediction_dict['rcnn_cls_targets']
        rcnn_reg_targets = prediction_dict['rcnn_reg_targets']

        # classification loss
        rcnn_cls_scores = prediction_dict['rcnn_cls_scores']
        rcnn_cls_loss = self.rcnn_cls_loss(rcnn_cls_scores, rcnn_cls_targets)
        rcnn_cls_loss *= rcnn_cls_weights
        rcnn_cls_loss = rcnn_cls_loss.sum(dim=-1)

        # bounding box regression loss
        rcnn_bbox_preds = prediction_dict['rcnn_bbox_preds']
        rcnn_bbox_loss = self.rcnn_bbox_loss(rcnn_bbox_preds,
                                             rcnn_reg_targets).sum(dim=-1)
        # Apply the weights exactly once: they are already divided by the
        # sample count, so a second multiplication would square that
        # normalisation and shrink the regression loss.
        rcnn_bbox_loss *= rcnn_reg_weights
        rcnn_bbox_loss = rcnn_bbox_loss.sum(dim=-1)

        # loss weights has no gradients
        loss_dict['rcnn_cls_loss'] = rcnn_cls_loss
        loss_dict['rcnn_bbox_loss'] = rcnn_bbox_loss

        # analysis ap
        rcnn_cls_probs = prediction_dict['rcnn_cls_probs']
        fake_match = prediction_dict['fake_match']
        num_gt = feed_dict['gt_labels'].numel()
        self.target_assigner.analyzer.analyze_ap(fake_match,
                                                 rcnn_cls_probs[:, 1],
                                                 num_gt,
                                                 thresh=0.5)

        return loss_dict

    def loss_new(self, prediction_dict, feed_dict):
        """
        Variant of ``loss`` that assigns targets here and subsamples the
        proposals using their own per-sample loss as the sampling criterion.

        Side effects: the assigner's regression weights are masked in place,
        the normalised regression weights are stored in
        ``prediction_dict['rcnn_reg_weights']`` and the class-1 probabilities
        are handed to the assigner's analyzer.

        Returns:
            The RPN loss dict extended with 'rcnn_cls_loss'. The regression
            loss is computed but not added to the result.

        NOTE(review): targets and weights are indexed with ``[0]`` below,
        which presumably assumes a batch of one image -- confirm.
        """
        loss_dict = {}

        # submodule loss
        loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

        rois_batch = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        ##########################
        # assigner
        ##########################
        # the first value of every RoI is dropped before matching
        rcnn_cls_targets, rcnn_reg_targets, rcnn_cls_weights, rcnn_reg_weights = self.target_assigner.assign(
            rois_batch[:, :, 1:], gt_boxes, gt_labels)

        ##########################
        # subsampler
        ##########################
        # the same mask serves as candidate set and as positive set
        indicator = rcnn_cls_weights > 0
        pos_indicator = indicator

        # unweighted per-sample classification loss (first batch entry)
        rcnn_cls_scores = prediction_dict['rcnn_cls_scores']
        rcnn_cls_loss = self.rcnn_cls_loss(rcnn_cls_scores,
                                           rcnn_cls_targets[0])

        # unweighted per-sample box regression loss, summed over the last axis
        rcnn_bbox_preds = prediction_dict['rcnn_bbox_preds']
        rcnn_bbox_loss = self.rcnn_bbox_loss(rcnn_bbox_preds,
                                             rcnn_reg_targets[0]).sum(dim=-1)

        # sampling criterion: weighted total loss of every proposal
        cls_criterion = rcnn_cls_loss * rcnn_cls_weights + rcnn_bbox_loss * rcnn_reg_weights
        # subsample from all
        batch_sampled_mask = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            pos_indicator,
            indicator=indicator,
            criterion=cls_criterion)
        rcnn_cls_weights = rcnn_cls_weights * batch_sampled_mask.type_as(
            rcnn_cls_weights)
        num_cls_coeff = (rcnn_cls_weights > 0).sum(dim=-1)
        # check
        assert num_cls_coeff, 'bug happens'

        # normalise by the number of sampled proposals
        rcnn_cls_weights = rcnn_cls_weights / num_cls_coeff.float()

        # classification loss (weighted in place, then summed)
        rcnn_cls_loss *= rcnn_cls_weights[0]
        rcnn_cls_loss = rcnn_cls_loss.sum(dim=-1)

        # bbox reg
        # in-place: this also changes the tensor returned by the assigner
        rcnn_reg_weights *= batch_sampled_mask.type_as(rcnn_reg_weights)
        num_reg_coeff = (rcnn_reg_weights > 0).sum(dim=-1)
        assert num_reg_coeff, 'bug happens'
        rcnn_reg_weights = rcnn_reg_weights / num_reg_coeff.float()

        rcnn_bbox_loss *= rcnn_reg_weights[0]
        rcnn_bbox_loss = rcnn_bbox_loss.sum(dim=-1)

        # only the classification loss is reported; the regression entry is
        # disabled:
        loss_dict['rcnn_cls_loss'] = rcnn_cls_loss
        # loss_dict['rcnn_bbox_loss'] = rcnn_bbox_loss

        # analysis precision; unlike `loss`, the match tensor is read straight
        # from the analyzer rather than from prediction_dict['fake_match']
        rcnn_cls_probs = prediction_dict['rcnn_cls_probs']
        fake_match = self.target_assigner.analyzer.match
        num_gt = feed_dict['gt_labels'].numel()
        self.target_assigner.analyzer.analyze_ap(fake_match,
                                                 rcnn_cls_probs[:, 1],
                                                 num_gt,
                                                 thresh=0.5)
        prediction_dict['rcnn_reg_weights'] = rcnn_reg_weights
        return loss_dict
예제 #14
0
class RefineOFTModel(Model):
    def _pad_or_crop(self, feed_dict):
        # import ipdb
        # ipdb.set_trace()
        img = feed_dict['img']
        img_shape = img.shape[-2:]
        target_shape = (1, 3, 384, 1280)
        new_image = torch.zeros(target_shape).type_as(img)
        new_image[:, :, :img_shape[0], :img_shape[1]] = img

        feed_dict['img'] = new_image

    def forward(self, feed_dict):
        """Run the first-stage and second-stage heads and return raw outputs.

        Pipeline, as coded below: pad/crop the image, project the voxels to
        the image, extract and fuse image features, build integral maps and
        OFT (bird's-eye-view) maps, predict first-stage outputs per BEV
        cell, decode them to 3D boxes, project those boxes to 2D rois, pool
        image features for the rois, concatenate them with the BEV features
        and predict the final scores/boxes.

        Args:
            feed_dict: dict; 'img', 'p2' and 'im_info' are read here. 'img'
                is overwritten by ``_pad_or_crop``.

        Returns:
            dict with the keys
            'pred_boxes_3d': second-stage regression output (decoded boxes
                with a decoded angle appended when not training),
            'pred_probs_3d': softmax of the second-stage scores (smoothed
                and suppressed when not training),
            'rpn_boxes_3d', 'rpn_probs_preds': slices of the detached
                first-stage output (raw values, no softmax applied).

        NOTE(review): several steps index batch element 0 only, so this
        presumably assumes batch_size == 1 -- confirm.
        """
        self._pad_or_crop(feed_dict)

        # The numbered profiler sections bracket the individual stages.
        self.profiler.start('1')
        self.voxel_generator.proj_voxels_3dTo2d(feed_dict['p2'],
                                                feed_dict['im_info'])
        self.profiler.end('1')

        self.profiler.start('2')
        img_feat_maps = self.feature_extractor.forward(feed_dict['img'])
        self.profiler.end('2')

        self.profiler.start('3')
        img_feat_maps = self.feature_preprocess(img_feat_maps)
        self.profiler.end('3')

        # early fusion in image level features
        img_feat_maps = self.img_feat_fusion(img_feat_maps)

        self.profiler.start('4')
        integral_maps = self.generate_integral_maps(img_feat_maps)
        self.profiler.end('4')

        self.profiler.start('5')
        oft_maps = self.generate_oft_maps(integral_maps)
        self.profiler.end('5')

        self.profiler.start('6')
        bev_feat_maps = self.feature_extractor.bev_feature(oft_maps)
        self.profiler.end('6')

        # pred output
        # shape (NCHW)
        self.profiler.start('7')
        rpn_output_maps = self.rpn_output_head(bev_feat_maps)

        # Group the voxel centers in runs of D and keep the first of each run.
        voxel_centers = self.voxel_generator.voxel_centers
        D = self.voxel_generator.lattice_dims[1]
        voxel_centers = voxel_centers.view(-1, D, 3)[:, 0, :]

        # NCHW -> (batch, cells, channels)
        rpn_output = rpn_output_maps.permute(0, 2, 3, 1).contiguous().view(
            self.batch_size, -1, self.rpn_output_channels)

        #####################################################
        # decode output of first stage to crop image features
        #####################################################

        # Everything derived from rpn_output below carries no gradient,
        # including the 'rpn_*' entries returned at the end.
        rpn_output = rpn_output.detach()
        rpn_bbox_preds = rpn_output[:, :, self.n_classes:]
        rpn_pred_scores = rpn_output[:, :, :self.n_classes]

        rpn_bbox_3d = self.bbox_coder.decode_batch_bbox(
            voxel_centers, rpn_bbox_preds)
        rpn_cls_probs = F.softmax(rpn_pred_scores, dim=-1)
        fg_rpn_cls_probs = rpn_cls_probs[:, :, -1]
        # Sort cells by descending foreground probability; no top-n cut is
        # applied, and only batch element 0 is reordered.
        order = torch.sort(fg_rpn_cls_probs, descending=True)[1]
        # topn = 2000
        # order = order[:, :topn]
        rpn_cls_probs = fg_rpn_cls_probs[0][order[0]].unsqueeze(0)
        rpn_bbox_3d = rpn_bbox_3d[0][order[0]].unsqueeze(0)

        # The first stage predicts no orientation, so a zero angle is used
        # for the projection.
        fake_ry = torch.zeros_like(rpn_bbox_3d[0, :, 0])

        target = {
            'dimension': rpn_bbox_3d[0, :, :3],
            'location': rpn_bbox_3d[0, :, 3:6],
            'ry': fake_ry
        }
        boxes_2d = Projector.proj_box_3to2img(target, feed_dict['p2'])
        # Prepend a zero column to the 2D boxes to form the rois.
        rois_idx = torch.zeros_like(boxes_2d[:, -1:])
        rois_2d = torch.cat([rois_idx, boxes_2d], dim=-1)

        rcnn_img_feat_maps = self.rcnn_pooling(img_feat_maps[0], rois_2d)

        # should do something for maps
        rcnn_img_feat_maps = self.feature_extractor.img_feat_extractor(
            rcnn_img_feat_maps)

        # Average over the two trailing (spatial) dims: one vector per roi.
        rcnn_img_feat_maps = rcnn_img_feat_maps.mean(dim=-2).mean(dim=-1)

        self.profiler.end('7')

        ###############################
        # second stage
        ###############################

        # Lay the per-roi vectors out on the BEV grid so they can be
        # concatenated with bev_feat_maps along the channel dim.
        # NOTE(review): the rois were reordered by score above while
        # bev_feat_maps keeps the original cell order -- confirm that this
        # pairing is intended.
        rcnn_img_feat_maps = rcnn_img_feat_maps.permute(
            1, 0).unsqueeze(0).contiguous().view(self.batch_size, -1,
                                                 *(bev_feat_maps.shape[-2:]))

        rcnn_output_maps = torch.cat([rcnn_img_feat_maps, bev_feat_maps],
                                     dim=1)
        output_maps = self.rcnn_output_head(rcnn_output_maps)

        # shape(N,M,out_channels)
        pred_3d = output_maps.permute(0, 2, 3, 1).contiguous().view(
            self.batch_size, -1, self.rcnn_output_channels)

        pred_boxes_3d = pred_3d[:, :, self.n_classes:]
        pred_scores_3d = pred_3d[:, :, :self.n_classes]

        pred_probs_3d = F.softmax(pred_scores_3d, dim=-1)
        self.add_feat('pred_scores_3d', output_maps[:, 1:2, :, :])
        self.add_feat('bev_feat_map', bev_feat_maps)

        if not self.training:
            # decode angle (channels from index 6 on)
            angles_oritations = self.bbox_coder.decode_batch_angle(
                pred_boxes_3d[:, :, 6:], self.angle_loss.bin_centers,
                self.num_bins)

            # decode the first six regression channels relative to the
            # voxel centers
            pred_boxes_3d = self.bbox_coder.decode_batch_bbox(
                voxel_centers, pred_boxes_3d[:, :, :6])

            pred_boxes_3d = torch.cat([pred_boxes_3d, angles_oritations],
                                      dim=-1)

            # gaussian filter probs map
            # reshape first
            shape = output_maps.shape[-2:]
            fg_mask = pred_probs_3d[0, :, 1].view(shape).detach().cpu().numpy()

            # then smooth
            from scipy.ndimage import gaussian_filter
            smoothed_fg_mask = gaussian_filter(fg_mask, sigma=self.nms_deltas)

            smoothed_fg_mask = torch.tensor(smoothed_fg_mask).type_as(
                pred_probs_3d)

            # nms
            smoothed_fg_mask = self.nms_map(smoothed_fg_mask)

            # assign back to tensor (in-place write into the softmax output)
            pred_probs_3d[0, :, 1] = smoothed_fg_mask.view(-1)

            # reset bg according to fg
            pred_probs_3d[0, :, 0] = 1 - pred_probs_3d[0, :, 1]

        prediction_dict = {}
        prediction_dict['pred_boxes_3d'] = pred_boxes_3d
        # prediction_dict['pred_scores_3d'] = pred_scores_3d
        prediction_dict['pred_probs_3d'] = pred_probs_3d

        # First-stage outputs: raw regression values and raw scores.
        prediction_dict['rpn_boxes_3d'] = rpn_output[:, :, self.n_classes:]
        prediction_dict['rpn_probs_preds'] = rpn_output[:, :, :self.n_classes]

        return prediction_dict

    def generate_proposals(self):
        pass

    def img_feat_fusion(self, img_feat_maps):
        # import ipdb
        # ipdb.set_trace()
        upconv3 = self.upconv3(img_feat_maps[2])
        upconv3 = self.upconv3_bn(upconv3)
        upconv3 = self.upconv3_relu(upconv3)

        sum2 = torch.cat([upconv3, img_feat_maps[1]], dim=1)
        fusion2 = self.fusion2(sum2)
        fusion2 = self.fusion2_bn(fusion2)
        fusion2 = self.relu2(fusion2)

        upconv2 = self.upconv2(img_feat_maps[1])
        upconv2 = self.upconv2_bn(upconv2)
        upconv2 = self.upconv2_relu(upconv2)

        sum1 = torch.cat([upconv2, img_feat_maps[0]], dim=1)
        fusion1 = self.fusion1(sum1)
        fusion1 = self.fusion1_bn(fusion1)
        fusion1 = self.relu1(fusion1)

        # import ipdb
        # ipdb.set_trace()
        # just return the finest map
        return [fusion1]

    def nms_map(self, smoothed_fg_mask):
        """
        supress the neibor
        """

        directions = [-1, 0, 1]
        shape = smoothed_fg_mask.shape
        orig_index = (torch.arange(shape[0]).cuda().long(),
                      torch.arange(shape[1]).cuda().long())
        orig_index = ops.meshgrid(orig_index[1], orig_index[0])
        orig_index = [orig_index[1], orig_index[0]]
        dest_indexes = []
        for i in directions:
            for j in directions:
                dest_index = (orig_index[0] + directions[i],
                              orig_index[1] + directions[j])
                dest_indexes.append(dest_index)

        nms_filter = torch.ones_like(smoothed_fg_mask).byte()
        orig_fg_mask = smoothed_fg_mask

        # pad fg mask first to prevent out of boundary
        padded_smoothed_fg_mask = torch.zeros(
            (shape[0] + 1, shape[1] + 1)).type_as(smoothed_fg_mask)
        padded_smoothed_fg_mask[:-1, :-1] = smoothed_fg_mask

        # import ipdb
        # ipdb.set_trace()
        for dest_index in dest_indexes:
            nms_filter = nms_filter & (
                orig_fg_mask >=
                padded_smoothed_fg_mask[dest_index].view_as(orig_fg_mask))

        # surpress
        smoothed_fg_mask[~nms_filter] = 0
        return smoothed_fg_mask

    def feature_preprocess(self, feat_maps):
        # import ipdb
        # ipdb.set_trace()
        reduced_feat_maps = []
        for ind, feat_map in enumerate(feat_maps):
            reduced_feat_map = self.feats_reduces[ind](feat_map)
            reduced_feat_maps.append(reduced_feat_map)
        return reduced_feat_maps

    def generate_integral_maps(self, img_feat_maps):
        integral_maps = []
        for img_feat_map in img_feat_maps:
            integral_maps.append(
                self.integral_map_generator.generate(img_feat_map))

        return integral_maps

    def generate_oft_maps(self, integral_maps):
        # shape(N,4)
        normalized_voxel_proj_2d = self.voxel_generator.normalized_voxel_proj_2d
        # for i in range(voxel_proj_2d.shape[0]):
        multiscale_img_feat = []
        for integral_map in integral_maps:
            multiscale_img_feat.append(
                self.integral_map_generator.calc(integral_map,
                                                 normalized_voxel_proj_2d))

        # shape(N,C,HWD)
        # only one image
        fusion_feat = multiscale_img_feat[0]
        depth_dim = self.voxel_generator.lattice_dims[1]
        height_dim = self.voxel_generator.lattice_dims[0]

        fusion_feat = fusion_feat.view(self.batch_size, self.feat_size, -1,
                                       depth_dim).permute(0, 3, 1,
                                                          2).contiguous()
        # shape(N,C,HW)
        oft_maps = self.feat_collapse(fusion_feat).view(
            self.batch_size, self.feat_size, height_dim, -1)

        return oft_maps

    def init_param(self, model_config):
        """Read hyper-parameters from ``model_config`` and build helpers.

        Args:
            model_config: mapping with the keys read below. 'nms_deltas' is
                optional and falls back to 1 when missing or None.
        """
        # values copied straight from the config: (attribute, config key)
        plain_settings = (
            ('feat_size', 'common_feat_size'),
            ('batch_size', 'batch_size'),
            ('sample_size', 'sample_size'),
            ('pooling_size', 'pooling_size'),
            ('n_classes', 'num_classes'),
            ('use_focal_loss', 'use_focal_loss'),
            ('feature_extractor_config', 'feature_extractor_config'),
        )
        for attr_name, config_key in plain_settings:
            setattr(self, attr_name, model_config[config_key])

        voxel_generator = VoxelGenerator(
            model_config['voxel_generator_config'])
        self.voxel_generator = voxel_generator
        voxel_generator.init_voxels()

        self.integral_map_generator = IntegralMapGenerator()

        # assigner used for the training targets
        self.oft_target_assigner = OFTargetAssigner(
            model_config['target_assigner_config'])

        # second assigner, built from the eval config; its analyzer is told
        # not to append ground truth
        eval_assigner = TargetAssigner(
            model_config['eval_target_assigner_config'])
        self.target_assigner = eval_assigner
        eval_assigner.analyzer.append_gt = False

        self.sampler = DetectionSampler(model_config['sampler_config'])

        # the box coder is shared with the training assigner
        self.bbox_coder = self.oft_target_assigner.bbox_coder

        # find the most expensive operators
        self.profiler = Profiler()

        # self.multibin = model_config['multibin']
        self.num_bins = model_config['num_bins']

        # 3 + 3 box channels plus 4 channels per angle bin
        self.reg_channels = 3 + 3 + self.num_bins * 4

        # score, pos, dim, ang
        self.rcnn_output_channels = self.n_classes + self.reg_channels

        self.rpn_output_channels = 2 + 3 + 3

        configured_deltas = model_config.get('nms_deltas')
        self.nms_deltas = 1 if configured_deltas is None else configured_deltas

    def init_modules(self):
        """
        Build the feature extractor, prediction heads, losses, fusion
        layers and roi pooling.

        NOTE(review): several channel counts below (128/256/512, 8, 1152,
        256 * 4) are hard-coded; they presumably have to match the feature
        extractor outputs and the voxel lattice depth -- confirm before
        changing the config.
        """

        self.feature_extractor = OFTNetFeatureExtractor(
            self.feature_extractor_config)

        # 1x1 convs mapping each image feature level to feat_size channels
        feats_reduce_1 = nn.Conv2d(128, self.feat_size, 1, 1, 0)
        feats_reduce_2 = nn.Conv2d(256, self.feat_size, 1, 1, 0)
        feats_reduce_3 = nn.Conv2d(512, self.feat_size, 1, 1, 0)
        self.feats_reduces = nn.ModuleList(
            [feats_reduce_1, feats_reduce_2, feats_reduce_3])

        # 1x1 conv from 8 channels to 1, applied in generate_oft_maps after
        # the depth dim has been moved to the channel position
        self.feat_collapse = nn.Conv2d(8, 1, 1, 1, 0)

        # second-stage head
        self.rcnn_output_head = nn.Conv2d(1152, self.rcnn_output_channels, 1,
                                          1, 0)

        # first-stage head
        self.rpn_output_head = nn.Conv2d(256 * 4, self.rpn_output_channels, 1,
                                         1, 0)

        # loss
        # reduce=False is the legacy way of asking for an unreduced,
        # per-element loss
        self.reg_loss = nn.L1Loss(reduce=False)
        # self.reg_loss = nn.SmoothL1Loss(reduce=False)
        # if self.use_focal_loss:
        # self.conf_loss = FocalLoss(
        # self.n_classes, alpha=0.2, gamma=2, auto_alpha=False)
        # else:
        # self.conf_loss = nn.CrossEntropyLoss(reduce=False)
        # the confidence loss is an L1 loss as well
        self.conf_loss = nn.L1Loss(reduce=False)

        self.angle_loss = MultiBinLoss(num_bins=self.num_bins)

        # fusion layer
        # self.upconv1 = nn.ConvTranspose2d(self.feat_size, self.feat_size, 2, 2,
        # 0)
        # fusion convs take 2 * feat_size channels because they are applied
        # to a concatenation of two feat_size maps in img_feat_fusion
        self.fusion1 = nn.Conv2d(2 * self.feat_size, self.feat_size, 3, 1, 1)
        self.fusion1_bn = nn.BatchNorm2d(self.feat_size)
        self.relu1 = nn.ReLU()
        # kernel 2, stride 2 transposed conv
        self.upconv2 = nn.ConvTranspose2d(self.feat_size, self.feat_size, 2, 2,
                                          0)
        self.upconv2_bn = nn.BatchNorm2d(self.feat_size)
        self.upconv2_relu = nn.ReLU()

        self.relu2 = nn.ReLU()
        self.fusion2 = nn.Conv2d(2 * self.feat_size, self.feat_size, 3, 1, 1)
        self.fusion2_bn = nn.BatchNorm2d(self.feat_size)

        self.upconv3 = nn.ConvTranspose2d(self.feat_size, self.feat_size, 2, 2,
                                          0)
        self.upconv3_bn = nn.BatchNorm2d(self.feat_size)
        self.upconv3_relu = nn.ReLU()
        # self.fusion3 = nn.Conv2d(self.feat_size, self.feat_size, 3, 1, 1)
        # self.relu3 = nn.ReLU()

        # roi pooling on the fused image feature map
        # NOTE(review): 1.0 / 8.0 is presumably the spatial scale of that
        # map relative to the input image -- confirm.
        self.rcnn_pooling = RoIAlignAvg(self.pooling_size, self.pooling_size,
                                        1.0 / 8.0)

    def init_weights(self):
        self.feature_extractor.init_weights()

    def rpn_loss(self, preds, targets, weights):
        rpn_cls_probs = preds['pred_probs_3d'][:, :, -1]
        rpn_bbox_preds = preds['pred_boxes_3d']
        cls_targets = targets['cls_targets']
        cls_weights = weights['cls_weights']
        reg_weights = weights['reg_weights']
        reg_targets = targets['reg_targets']

        # cls loss
        rpn_cls_loss = self.conf_loss(rpn_cls_probs, cls_targets)
        rpn_cls_loss = rpn_cls_loss.view_as(cls_weights)
        rpn_cls_loss = rpn_cls_loss * cls_weights
        rpn_cls_loss = rpn_cls_loss.mean(dim=-1)

        # bbox loss
        rpn_reg_loss = self.reg_loss(rpn_bbox_preds[:, :, :6],
                                     reg_targets[:, :, :-1])
        rpn_reg_loss = rpn_reg_loss * reg_weights.unsqueeze(-1)
        num_reg_coeff = (reg_weights > 0).sum(dim=-1)
        num_reg_coeff = num_reg_coeff.type_as(reg_weights)
        rpn_reg_loss = rpn_reg_loss.sum(dim=-1).sum(dim=-1) / num_reg_coeff

        # angle_loss
        # angle_loss, angle_tp_mask = self.angle_loss(rpn_bbox_preds[:, :, 6:],
        # reg_targets[:, :, -1:])
        # rpn_angle_loss = angle_loss * reg_weights
        return {
            # 'rpn_angle_loss': rpn_angle_loss,
            'rpn_cls_loss': rpn_cls_loss,
            'rpn_reg_loss': rpn_reg_loss
        }

    def loss(self, prediction_dict, feed_dict):
        """Compute all training losses and update the evaluation statistics.

        Args:
            prediction_dict: output of ``forward``; 'pred_probs_3d',
                'pred_boxes_3d', 'rpn_boxes_3d' and 'rpn_probs_preds' are
                read, and 'rcnn_reg_weights' is written.
            feed_dict: dict read for 'gt_boxes_3d', 'gt_labels',
                'gt_boxes_ground_2d_rect', 'gt_boxes' and 'p2'.

        Returns:
            dict with 'cls_loss', 'dim_loss', 'pos_loss', 'angle_loss' for
            the second stage plus 'rpn_cls_loss' and 'rpn_reg_loss' from
            ``rpn_loss``.

        Side effects: calls the evaluation target assigner and its analyzer
        and updates ``self.target_assigner.stat`` with the angle counts.
        """
        self.profiler.start('8')
        gt_boxes_3d = feed_dict['gt_boxes_3d']
        gt_labels = feed_dict['gt_labels']
        gt_boxes_ground_2d_rect = feed_dict['gt_boxes_ground_2d_rect']

        voxels_ground_2d = self.voxel_generator.proj_voxels_to_ground()
        # Group the voxel centers in runs of D and keep the first of each run.
        voxel_centers = self.voxel_generator.voxel_centers
        D = self.voxel_generator.lattice_dims[1]
        voxel_centers = voxel_centers.view(-1, D, 3)[:, 0, :]

        # gt_boxes_3d = torch.cat([gt_boxes_3d[:,:,:3],],dim=-1)

        cls_weights, reg_weights, cls_targets, reg_targets = self.oft_target_assigner.assign(
            voxels_ground_2d, gt_boxes_ground_2d_rect, voxel_centers,
            gt_boxes_3d, gt_labels)

        # number of cells with a positive regression weight
        num_reg_coeff = (reg_weights > 0).sum(dim=-1)

        # Avoid dividing by zero below.
        # NOTE(review): this truth test needs a single-element tensor, so
        # it presumably assumes batch_size == 1 -- confirm.
        if num_reg_coeff == 0:
            num_reg_coeff = torch.ones([]).type_as(num_reg_coeff)

        # cls loss (on the last channel of the predicted probabilities)
        rpn_cls_probs = prediction_dict['pred_probs_3d'][:, :, -1]
        rpn_cls_loss = self.conf_loss(rpn_cls_probs, cls_targets)
        rpn_cls_loss = rpn_cls_loss.view_as(cls_weights)
        rpn_cls_loss = rpn_cls_loss * cls_weights
        rpn_cls_loss = rpn_cls_loss.mean(dim=-1)

        # bbox loss (first six channels against all but the last target
        # channel)
        rpn_bbox_preds = prediction_dict['pred_boxes_3d']
        rpn_reg_loss = self.reg_loss(rpn_bbox_preds[:, :, :6],
                                     reg_targets[:, :, :-1])
        rpn_reg_loss = rpn_reg_loss * reg_weights.unsqueeze(-1)
        num_reg_coeff = num_reg_coeff.type_as(reg_weights)

        # angle_loss (remaining channels against the last target channel)
        angle_loss, angle_tp_mask = self.angle_loss(rpn_bbox_preds[:, :, 6:],
                                                    reg_targets[:, :, -1:])
        rpn_angle_loss = angle_loss * reg_weights

        # split reg loss: channels 0-2 -> 'dim_loss', 3-5 -> 'pos_loss'
        dim_loss = rpn_reg_loss[:, :, :3].sum(dim=-1).sum(
            dim=-1) / num_reg_coeff
        pos_loss = rpn_reg_loss[:, :,
                                3:6].sum(dim=-1).sum(dim=-1) / num_reg_coeff
        angle_loss = rpn_angle_loss.sum(dim=-1).sum(dim=-1) / num_reg_coeff

        prediction_dict['rcnn_reg_weights'] = reg_weights

        loss_dict = {}

        loss_dict['cls_loss'] = rpn_cls_loss
        # loss_dict['rpn_bbox_loss'] = rpn_reg_loss
        # split bbox loss instead of fusing them
        loss_dict['dim_loss'] = dim_loss
        loss_dict['pos_loss'] = pos_loss
        loss_dict['angle_loss'] = angle_loss

        self.profiler.end('8')

        #################################
        # First stage loss
        #################################
        # The first-stage outputs are scored against the same targets and
        # weights as the second stage.
        preds = {
            'pred_boxes_3d': prediction_dict['rpn_boxes_3d'],
            'pred_probs_3d': prediction_dict['rpn_probs_preds']
        }
        targets = {'cls_targets': cls_targets, 'reg_targets': reg_targets}
        weights = {'reg_weights': reg_weights, 'cls_weights': cls_weights}
        loss_dict.update(self.rpn_loss(preds, targets, weights))

        ###################################
        # Statistic
        ###################################

        voxel_centers = self.voxel_generator.voxel_centers
        D = self.voxel_generator.lattice_dims[1]
        voxel_centers = voxel_centers.view(-1, D, 3)[:, 0, :]
        # decode bbox
        pred_boxes_3d = self.bbox_coder.decode_batch_bbox(
            voxel_centers, rpn_bbox_preds[:, :, :6])
        # decode angle
        angles_oritations = self.bbox_coder.decode_batch_angle(
            rpn_bbox_preds[:, :, 6:], self.angle_loss.bin_centers,
            self.num_bins)
        pred_boxes_3d = torch.cat([pred_boxes_3d, angles_oritations], dim=-1)

        # select the top n by predicted probability (batch element 0 only)
        order = torch.sort(rpn_cls_probs, descending=True)[1]
        topn = 1000
        order = order[:, :topn]
        rpn_cls_probs = rpn_cls_probs[0][order[0]].unsqueeze(0)
        pred_boxes_3d = pred_boxes_3d[0][order[0]].unsqueeze(0)

        target = {
            'dimension': pred_boxes_3d[0, :, :3],
            'location': pred_boxes_3d[0, :, 3:6],
            'ry': pred_boxes_3d[0, :, 6]
        }

        # project the selected boxes to the image and match them to the 2D
        # ground truth; the match result is read back from the analyzer
        boxes_2d = Projector.proj_box_3to2img(target, feed_dict['p2'])
        gt_boxes = feed_dict['gt_boxes']
        num_gt = gt_labels.numel()
        self.target_assigner.assign(boxes_2d, gt_boxes, eval_thresh=0.7)

        fake_match = self.target_assigner.analyzer.match
        self.target_assigner.analyzer.analyze_ap(fake_match,
                                                 rpn_cls_probs,
                                                 num_gt,
                                                 thresh=0.1)

        # angle stats, restricted to cells with a positive regression weight
        angle_tp_mask = angle_tp_mask[reg_weights > 0]
        angles_tp_num = angle_tp_mask.int().sum().item()
        angles_all_num = angle_tp_mask.numel()

        self.target_assigner.stat.update({
            'cls_orient_2s_all_num': angles_all_num,
            'cls_orient_2s_tp_num': angles_tp_num
        })

        return loss_dict
예제 #15
0
class SemanticFasterRCNN(Model):
    def forward(self, feed_dict):
        """Run the RPN and the second stage, and collect matching stats.

        Args:
            feed_dict: dict read for 'img', 'gt_boxes' and 'gt_labels'; the
                first-stage feature map is stored into it as 'base_feat'.

        Returns:
            prediction_dict: the RPN model outputs plus 'rcnn_cls_probs',
            'rcnn_bbox_preds', 'rcnn_cls_scores', 'second_rpn_anchors' and
            'rcnn_rois_batch'.

        Side effects: resets ``self.stats`` / ``self.rcnn_stats`` and fills
        them with the assigner and analyzer results.
        """
        self.clean_stats()

        prediction_dict = {}

        # base model
        base_feat = self.feature_extractor.first_stage_feature(
            feed_dict['img'])
        feed_dict.update({'base_feat': base_feat})
        # batch_size = base_feat.shape[0]

        # rpn model
        prediction_dict.update(self.rpn_model.forward(feed_dict))

        # proposals = prediction_dict['proposals_batch']
        # shape(N,num_proposals,5)
        # pre subsample for reduce consume of memory
        # (pre_subsample also replaces prediction_dict['rois_batch'])
        if self.training:
            stats = self.pre_subsample(prediction_dict, feed_dict)
            self.stats.update(stats)
        rois_batch = prediction_dict['rois_batch']

        # note here base_feat (N,C,H,W),rois_batch (N,num_proposals,5)
        pooled_feat = self.rcnn_pooling(base_feat, rois_batch.view(-1, 5))

        # shape(N,C,1,1)
        pooled_feat = self.feature_extractor.second_stage_feature(pooled_feat)

        # semantic map
        # if self.use_self_attention:
        # pooled_feat_cls = pooled_feat.mean(3).mean(2)
        # rcnn_cls_scores = self.rcnn_cls_pred(pooled_feat_cls)
        # rcnn_cls_probs = F.softmax(rcnn_cls_scores, dim=1)

        # # self-attention
        # channel_attention = self.generate_channel_attention(pooled_feat)
        # spatial_attention = self.generate_spatial_attention(pooled_feat)
        # pooled_feat_reg = pooled_feat * channel_attention
        # pooled_feat_reg = pooled_feat * spatial_attention
        # pooled_feat_reg = pooled_feat_reg.mean(3).mean(2)

        # rcnn_bbox_preds = self.rcnn_bbox_pred(pooled_feat_reg)
        # else:
        # Per-position class scores: their spatial mean gives the roi-level
        # scores, their per-position softmax gives a saliency map.
        rcnn_cls_scores_map = self.rcnn_cls_pred(pooled_feat)
        rcnn_cls_scores = rcnn_cls_scores_map.mean(3).mean(2)
        saliency_map = F.softmax(rcnn_cls_scores_map, dim=1)
        rcnn_cls_probs = F.softmax(rcnn_cls_scores, dim=1)
        # rcnn_cls_probs = rcnn_cls_probs_map.mean(3).mean(2)
        # shape(N,C)
        # Weight the pooled features with the saliency channels from index
        # 1 on, then average spatially.
        # NOTE(review): the broadcast presumably relies on exactly two
        # classes (a single saliency channel) -- confirm.
        rcnn_bbox_feat = pooled_feat * saliency_map[:, 1:, :, :]
        # rcnn_bbox_feat = torch.cat([rcnn_bbox_feat, pooled_feat], dim=1)
        rcnn_bbox_feat = rcnn_bbox_feat.mean(3).mean(2)

        # if self.use_score:
        # pooled_feat =

        rcnn_bbox_preds = self.rcnn_bbox_pred(rcnn_bbox_feat)

        prediction_dict['rcnn_cls_probs'] = rcnn_cls_probs
        prediction_dict['rcnn_bbox_preds'] = rcnn_bbox_preds
        prediction_dict['rcnn_cls_scores'] = rcnn_cls_scores

        # used for track
        proposals_order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = prediction_dict['anchors'][
            proposals_order]

        # Decode the regression output relative to the rois; column 0 of
        # rcnn_rois_batch stays zero.
        # NOTE(review): view(1, -1, 4) presumably assumes a single image
        # per batch -- confirm.
        pred_boxes = self.bbox_coder.decode_batch(
            rcnn_bbox_preds.view(1, -1, 4), rois_batch[:, :, 1:5])
        rcnn_rois_batch = torch.zeros_like(rois_batch)
        rcnn_rois_batch[:, :, 1:5] = pred_boxes.detach()
        prediction_dict['rcnn_rois_batch'] = rcnn_rois_batch

        # if self.training:
        # # append gt
        # rcnn_rois_batch = self.append_gt(rcnn_rois_batch,
        # feed_dict['gt_boxes'])
        # prediction_dict['rcnn_rois_batch'] = rcnn_rois_batch

        ###################################
        # stats
        ###################################

        # when enable cls, skip it
        # The last element returned by assign() is the stats dict.
        stats = self.target_assigner.assign(rcnn_rois_batch[:, :, 1:],
                                            feed_dict['gt_boxes'],
                                            feed_dict['gt_labels'])[-1]
        self.rcnn_stats.update(stats)

        # analysis ap
        # when enable cls, otherwise it is no sense
        if self.training:
            rcnn_cls_probs = prediction_dict['rcnn_cls_probs']
            num_gt = feed_dict['gt_labels'].numel()
            fake_match = self.rcnn_stats['match']
            stats = self.target_assigner.analyzer.analyze_ap(fake_match,
                                                             rcnn_cls_probs[:,
                                                                            1],
                                                             num_gt,
                                                             thresh=0.5)
            # collect stats
            self.rcnn_stats.update(stats)

        return prediction_dict

    def clean_stats(self):
        # rois bbox
        self.stats = {
            'num_det': 1,
            'num_tp': 0,
            'matched_thresh': 0,
            'recall_thresh': 0,
            'match': None,
            # 'matched': 0,
            # 'num_gt': 1,
        }

        # rcnn bbox(final bbox)
        self.rcnn_stats = {
            'num_det': 1,
            'num_tp': 0,
            'matched_thresh': 0,
            'recall_thresh': 0,
            'match': None,
            # 'matched': 0,
        }

    def generate_channel_attention(self, feat):
        return feat.mean(3, keepdim=True).mean(2, keepdim=True)

    def generate_spatial_attention(self, feat):
        return self.spatial_attention(feat)

    def init_weights(self):
        """Initialise the sub-models, then the two prediction heads.

        The heads get a zero-mean normal init via ``Filler.normal_init``:
        std 0.01 for the classification head, 0.001 for the box head.
        """
        # submodule init weights
        for submodel in (self.feature_extractor, self.rpn_model):
            submodel.init_weights()

        head_stds = (
            (self.rcnn_cls_pred, 0.01),
            (self.rcnn_bbox_pred, 0.001),
        )
        for head, std in head_stds:
            Filler.normal_init(head, 0, std, self.truncated)

    def init_modules(self):
        """Build the feature extractor, RPN, roi pooling, heads and losses.

        Reads the settings stored by ``init_param``.

        Raises:
            NotImplementedError: for pooling modes 'psalign' and
                'deformable_psalign'.
        """
        self.feature_extractor = feature_extractors_builder.build(
            self.feature_extractor_config)
        # self.feature_extractor = ResNetFeatureExtractor(
        # self.feature_extractor_config)
        # self.feature_extractor = MobileNetFeatureExtractor(
        # self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)
        # NOTE(review): any pooling mode other than the four listed leaves
        # self.rcnn_pooling unset, which only surfaces later in forward().
        if self.pooling_mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(self.pooling_size,
                                            self.pooling_size, 1.0 / 16.0)
        elif self.pooling_mode == 'ps':
            # NOTE(review): the 7s are hard-coded here instead of using
            # self.pooling_size -- confirm this is intended.
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)
        elif self.pooling_mode == 'psalign':
            raise NotImplementedError('have not implemented yet!')
        elif self.pooling_mode == 'deformable_psalign':
            raise NotImplementedError('have not implemented yet!')
        # Classification head: linear layer with self-attention enabled,
        # otherwise a 3x3 conv producing a per-position score map.
        # NOTE(review): forward() treats the output as a 4-D score map
        # (.mean(3).mean(2)), which presumably only fits the conv variant
        # -- confirm.
        if self.use_self_attention:
            self.rcnn_cls_pred = nn.Linear(self.ndin, self.n_classes)
        else:
            self.rcnn_cls_pred = nn.Conv2d(self.ndin, self.n_classes, 3, 1, 1)
        # Box head: 4 outputs when class agnostic, otherwise 4 per class.
        if self.class_agnostic:
            self.rcnn_bbox_pred = nn.Linear(self.ndin, 4)
            # self.rcnn_bbox_pred = nn.Conv2d(2048,4,3,1,1)
        else:
            self.rcnn_bbox_pred = nn.Linear(self.ndin, 4 * self.n_classes)

        # loss module
        # reduce=False is the legacy way of asking for an unreduced,
        # per-element loss
        if self.use_focal_loss:
            self.rcnn_cls_loss = FocalLoss(2)
        else:
            self.rcnn_cls_loss = functools.partial(F.cross_entropy,
                                                   reduce=False)

        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

        # attention
        if self.use_self_attention:
            self.spatial_attention = nn.Conv2d(self.ndin, 1, 3, 1, 1)

    def init_param(self, model_config):
        """Read hyper-parameters from ``model_config`` and build helpers.

        Args:
            model_config: mapping with the keys read below. 'din' and
                'use_self_attention' are optional; a missing or falsy 'din'
                falls back to 512.
        """
        self.ndin = model_config.get('din') or 512

        self.classes = model_config['classes']
        self.n_classes = len(self.classes)

        # settings stored under the same name as their config key
        same_name_keys = (
            'class_agnostic',
            'pooling_size',
            'pooling_mode',
            'crop_resize_with_max_pool',
            'truncated',
            'use_focal_loss',
            'subsample_twice',
            'rcnn_batch_size',
        )
        for key in same_name_keys:
            setattr(self, key, model_config[key])
        self.use_self_attention = model_config.get('use_self_attention')

        # some submodule config
        self.feature_extractor_config = model_config[
            'feature_extractor_config']
        self.rpn_config = model_config['rpn_config']

        # assigner, whose bbox coder is reused by this model
        assigner = TargetAssigner(model_config['target_assigner_config'])
        self.target_assigner = assigner
        self.bbox_coder = assigner.bbox_coder

        # sampler
        self.sampler = BalancedSampler(model_config['sampler_config'])

    def pre_subsample(self, prediction_dict, feed_dict):
        rois_batch = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        ##########################
        # assigner
        ##########################
        #  import ipdb
        #  ipdb.set_trace()
        rcnn_cls_targets, rcnn_reg_targets, rcnn_cls_weights, rcnn_reg_weights, stats = self.target_assigner.assign(
            rois_batch[:, :, 1:], gt_boxes, gt_labels)

        ##########################
        # subsampler
        ##########################
        cls_criterion = None
        pos_indicator = rcnn_reg_weights > 0
        indicator = rcnn_cls_weights > 0

        # subsample from all
        # shape (N,M)
        batch_sampled_mask = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            pos_indicator,
            indicator=indicator,
            criterion=cls_criterion)
        rcnn_cls_weights = rcnn_cls_weights[batch_sampled_mask]
        rcnn_reg_weights = rcnn_reg_weights[batch_sampled_mask]
        num_cls_coeff = (rcnn_cls_weights > 0).sum(dim=-1)
        num_reg_coeff = (rcnn_reg_weights > 0).sum(dim=-1)
        # check
        assert num_cls_coeff, 'bug happens'
        num_reg_coeff = torch.max(num_reg_coeff,
                                  torch.ones_like(num_reg_coeff))

        prediction_dict[
            'rcnn_cls_weights'] = rcnn_cls_weights / num_cls_coeff.float()
        prediction_dict[
            'rcnn_reg_weights'] = rcnn_reg_weights / num_reg_coeff.float()
        prediction_dict['rcnn_cls_targets'] = rcnn_cls_targets[
            batch_sampled_mask]
        prediction_dict['rcnn_reg_targets'] = rcnn_reg_targets[
            batch_sampled_mask]
        prediction_dict['fake_match'] = self.target_assigner.analyzer.match[
            batch_sampled_mask]

        # update rois_batch
        prediction_dict['rois_batch'] = rois_batch[batch_sampled_mask].view(
            rois_batch.shape[0], -1, 5)

        stats['match'] = stats['match'][batch_sampled_mask]

        return stats

    def loss(self, prediction_dict, feed_dict):
        """
        assign proposals label and subsample from them
        Then calculate loss
        """
        loss_dict = {}

        # submodule loss
        loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

        # targets and weights
        rcnn_cls_weights = prediction_dict['rcnn_cls_weights']
        rcnn_reg_weights = prediction_dict['rcnn_reg_weights']

        rcnn_cls_targets = prediction_dict['rcnn_cls_targets']
        rcnn_reg_targets = prediction_dict['rcnn_reg_targets']

        # classification loss
        rcnn_cls_scores = prediction_dict['rcnn_cls_scores']
        rcnn_cls_loss = self.rcnn_cls_loss(rcnn_cls_scores, rcnn_cls_targets)
        rcnn_cls_loss *= rcnn_cls_weights
        rcnn_cls_loss = rcnn_cls_loss.sum(dim=-1)

        # bounding box regression L1 loss
        rcnn_bbox_preds = prediction_dict['rcnn_bbox_preds']
        rcnn_bbox_loss = self.rcnn_bbox_loss(rcnn_bbox_preds,
                                             rcnn_reg_targets).sum(dim=-1)
        rcnn_bbox_loss *= rcnn_reg_weights
        # rcnn_bbox_loss *= rcnn_reg_weights
        rcnn_bbox_loss = rcnn_bbox_loss.sum(dim=-1)

        # loss weights has no gradients
        loss_dict['rcnn_cls_loss'] = rcnn_cls_loss
        loss_dict['rcnn_bbox_loss'] = rcnn_bbox_loss

        # add rcnn_cls_targets to get the statics of rpn
        # loss_dict['rcnn_cls_targets'] = rcnn_cls_targets

        # analysis ap
        rcnn_cls_probs = prediction_dict['rcnn_cls_probs']
        num_gt = feed_dict['gt_labels'].numel()
        fake_match = prediction_dict['fake_match']
        self.target_assigner.analyzer.analyze_ap(fake_match,
                                                 rcnn_cls_probs[:, 1],
                                                 num_gt,
                                                 thresh=0.5)

        return loss_dict
class CascadeFasterRCNN(Model):
    def forward(self, feed_dict):
        """
        Run the RPN followed by two RCNN head passes.

        Args:
            feed_dict: mapping that must provide ``'img'`` and
                ``'im_info'``; ``'base_feat'`` is written into it here.
                In training mode ``pre_subsample`` additionally reads
                ``'gt_boxes'`` and ``'gt_labels'`` from it.

        Returns:
            dict: the RPN predictions plus per-stage entries
            ``rcnn_{cls_probs,bbox_preds,cls_scores}_{1,2}``,
            ``second_rpn_anchors`` and the un-suffixed
            ``rcnn_cls_probs`` / ``rcnn_bbox_preds`` / ``rcnn_cls_scores``
            aliases.
        """
        prediction_dict = {}

        # base model
        base_feat = self.feature_extractor.first_stage_feature(
            feed_dict['img'])
        feed_dict.update({'base_feat': base_feat})

        # rpn model
        prediction_dict.update(self.rpn_model.forward(feed_dict))

        # pre subsample to reduce memory consumption; this replaces
        # prediction_dict['rois_batch'] with the sampled rois
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict)
        rois_batch = prediction_dict['rois_batch']

        # rois are flattened to rows of 5 values before pooling
        pooled_feat = self.rcnn_pooling(base_feat, rois_batch.view(-1, 5))
        pooled_feat = self.feature_extractor.second_stage_feature(pooled_feat)
        if self.reduce:
            # average over the two trailing (spatial) axes
            pooled_feat = pooled_feat.mean(3).mean(2)
        else:
            pooled_feat = pooled_feat.view(self.rcnn_batch_size, -1)

        rcnn_bbox_preds = self.rcnn_bbox_pred(pooled_feat)
        rcnn_cls_scores = self.rcnn_cls_pred(pooled_feat)
        rcnn_cls_probs = F.softmax(rcnn_cls_scores, dim=1)

        prediction_dict['rcnn_cls_probs_1'] = rcnn_cls_probs
        prediction_dict['rcnn_bbox_preds_1'] = rcnn_bbox_preds
        prediction_dict['rcnn_cls_scores_1'] = rcnn_cls_scores

        # used for track
        proposals_order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = prediction_dict['anchors'][0][
            proposals_order]

        ###########################
        # second stage
        ###########################
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict, stage_idx=2)

        rois_batch_2 = prediction_dict['rois_batch']
        # refine the rois with the first-stage predictions
        # NOTE(review): RPNModel.forward unpacks two values from
        # Proposal.apply; here a single value is used -- confirm which
        # Proposal implementation this class is paired with.
        rois_batch_2 = Proposal.apply(rcnn_cls_probs, rois_batch_2,
                                      rcnn_bbox_preds, feed_dict['im_info'])

        # pooling
        pooled_feat_2 = self.rcnn_pooling(base_feat, rois_batch_2.view(-1, 5))
        pooled_feat_2 = pooled_feat_2.mean(3).mean(2)

        # rcnn conv
        pooled_feat_2 = self.feature_extractor.third_stage_feature(
            pooled_feat_2)

        # reg and cls
        # NOTE(review): the second pass reuses rcnn_bbox_pred / rcnn_cls_pred
        # rather than the *_2 copies made when use_cascade is set -- confirm
        # that sharing the heads is intended.
        rcnn_bbox_preds_2 = self.rcnn_bbox_pred(pooled_feat_2)

        rcnn_cls_scores_2 = self.rcnn_cls_pred(pooled_feat_2)
        rcnn_cls_probs_2 = F.softmax(rcnn_cls_scores_2, dim=1)

        prediction_dict['rcnn_cls_probs_2'] = rcnn_cls_probs_2
        prediction_dict['rcnn_bbox_preds_2'] = rcnn_bbox_preds_2
        prediction_dict['rcnn_cls_scores_2'] = rcnn_cls_scores_2

        # compatible with train.py
        prediction_dict['rcnn_cls_probs'] = rcnn_cls_probs_2
        prediction_dict['rcnn_bbox_preds'] = rcnn_bbox_preds_2
        # NOTE(review): this alias points at the first-pass scores while the
        # two aliases above use the second pass -- left unchanged, confirm.
        prediction_dict['rcnn_cls_scores'] = rcnn_cls_scores
        return prediction_dict

    def init_weights(self):
        """Initialise the submodules, then the two RCNN prediction heads.

        The heads are filled through ``Filler.normal_init`` with mean 0 and
        a per-head standard deviation (0.01 for classification, 0.001 for
        box regression).
        """
        # submodules take care of their own weights
        for submodule_name in ('feature_extractor', 'rpn_model'):
            getattr(self, submodule_name).init_weights()

        head_stddevs = (('rcnn_cls_pred', 0.01), ('rcnn_bbox_pred', 0.001))
        for head_name, stddev in head_stddevs:
            Filler.normal_init(
                getattr(self, head_name), 0, stddev, self.truncated)

    def init_modules(self):
        """
        Build the feature extractor, RPN, RoI pooling layer, prediction
        heads and loss callables.

        Both prediction heads take the same pooled feature in ``forward``,
        so they are sized from one shared ``in_channels`` value (2048 when
        ``self.reduce`` is set, otherwise 2048 * 4 * 4).

        Raises:
            NotImplementedError: for the ``'psalign'`` and
                ``'deformable_psalign'`` pooling modes.
        """
        self.feature_extractor = ResNetFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)

        # NOTE(review): any other pooling_mode leaves self.rcnn_pooling
        # unset, which only surfaces later in forward().
        if self.pooling_mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(self.pooling_size,
                                            self.pooling_size, 1.0 / 16.0)
        elif self.pooling_mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)
        elif self.pooling_mode == 'psalign':
            raise NotImplementedError('have not implemented yet!')
        elif self.pooling_mode == 'deformable_psalign':
            raise NotImplementedError('have not implemented yet!')

        # the classification head used to hard-code 2048 inputs, which
        # disagreed with the regression head whenever reduce was False
        in_channels = 2048 if self.reduce else 2048 * 4 * 4
        self.rcnn_cls_pred = nn.Linear(in_channels, self.n_classes)
        if self.class_agnostic:
            self.rcnn_bbox_pred = nn.Linear(in_channels, 4)
        else:
            self.rcnn_bbox_pred = nn.Linear(in_channels, 4 * self.n_classes)
        if self.use_cascade:
            self.rcnn_cls_pred_2 = copy.deepcopy(self.rcnn_cls_pred)
            self.rcnn_bbox_pred_2 = copy.deepcopy(self.rcnn_bbox_pred)

        # loss module; both losses are left unreduced so callers can
        # apply their own per-sample weights
        if self.use_focal_loss:
            self.rcnn_cls_loss = FocalLoss(2)
        else:
            self.rcnn_cls_loss = functools.partial(
                F.cross_entropy, reduce=False)

        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

    def init_param(self, model_config):
        """Copy hyper-parameters from ``model_config`` onto the model.

        Every key is required except ``'use_cascade'``, which is read with
        ``dict.get``.  The target assigner and the sampler are constructed
        here from their own sub-configs, and ``self.reduce`` is fixed to
        ``True`` regardless of the config.
        """
        class_names = model_config['classes']
        self.classes = class_names
        self.n_classes = len(class_names)

        # plain values stored under the same name as their config key
        for key in ('class_agnostic', 'pooling_size', 'pooling_mode',
                    'crop_resize_with_max_pool', 'truncated'):
            setattr(self, key, model_config[key])
        self.use_cascade = model_config.get('use_cascade')
        for key in ('use_focal_loss', 'subsample_twice', 'rcnn_batch_size',
                    'fg_thresh_arr', 'bg_thresh_arr',
                    # submodule configs
                    'feature_extractor_config', 'rpn_config'):
            setattr(self, key, model_config[key])

        # assigner, and the bbox coder it owns
        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])
        self.bbox_coder = self.target_assigner.bbox_coder

        # sampler
        self.sampler = BalancedSampler(model_config['sampler_config'])

        # not taken from the config
        self.reduce = True

    def pre_subsample(self, prediction_dict, feed_dict, stage_idx=0):
        # if stage_idx:
        # rois_batch = prediction_dict['rois_batch_' + str(stage_idx)]
        # else:
        rois_batch = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        ##########################
        # assigner
        ##########################
        # import ipdb
        # ipdb.set_trace()
        self.target_assigner.fg_thresh = self.fg_thresh_arr[stage_idx]
        self.target_assigner.bg_thresh = self.bg_thresh_arr[stage_idx]
        stage_idx = str(stage_idx)
        rcnn_cls_targets, rcnn_reg_targets, rcnn_cls_weights, rcnn_reg_weights = self.target_assigner.assign(
            rois_batch[:, :, 1:], gt_boxes, gt_labels)

        ##########################
        # subsampler
        ##########################
        cls_criterion = None
        pos_indicator = rcnn_cls_targets > 0
        indicator = rcnn_cls_weights > 0

        # subsample from all
        # shape (N,M)
        batch_sampled_mask = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            pos_indicator,
            indicator=indicator,
            criterion=cls_criterion)
        rcnn_cls_weights = rcnn_cls_weights[batch_sampled_mask]
        rcnn_reg_weights = rcnn_reg_weights[batch_sampled_mask]
        num_cls_coeff = rcnn_cls_weights.type(torch.cuda.ByteTensor).sum(
            dim=-1)
        num_reg_coeff = rcnn_reg_weights.type(torch.cuda.ByteTensor).sum(
            dim=-1)
        # check
        assert num_cls_coeff, 'bug happens'
        assert num_reg_coeff, 'bug happens'

        prediction_dict['rcnn_cls_weights_' +
                        stage_idx] = rcnn_cls_weights / num_cls_coeff.float()
        prediction_dict['rcnn_reg_weights_' +
                        stage_idx] = rcnn_reg_weights / num_reg_coeff.float()
        prediction_dict['rcnn_cls_targets_' + stage_idx] = rcnn_cls_targets[
            batch_sampled_mask]
        prediction_dict['rcnn_reg_targets_' + stage_idx] = rcnn_reg_targets[
            batch_sampled_mask]

        # update rois_batch
        prediction_dict['rois_batch'] = rois_batch[batch_sampled_mask].view(
            rois_batch.shape[0], -1, 5)

        if not self.training:
            # used for track
            proposals_order = prediction_dict['proposals_order']

            prediction_dict['proposals_order'] = proposals_order[
                batch_sampled_mask]

    def loss(self, prediction_dict, feed_dict, num_stage=2):
        """
        assign proposals label and subsample from them
        Then calculate loss
        """
        loss_dict = {}

        # submodule loss
        loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

        for stage_idx in range(1, num_stage):
            # targets and weights
            rcnn_cls_weights = prediction_dict['rcnn_cls_weights_' + stage_idx]
            rcnn_reg_weights = prediction_dict['rcnn_reg_weights_' + stage_idx]

            rcnn_cls_targets = prediction_dict['rcnn_cls_targets_' + stage_idx]
            rcnn_reg_targets = prediction_dict['rcnn_reg_targets_' + stage_idx]

            # classification loss
            rcnn_cls_scores = prediction_dict['rcnn_cls_scores_' + stage_idx]
            rcnn_cls_loss = self.rcnn_cls_loss(rcnn_cls_scores,
                                               rcnn_cls_targets)
            rcnn_cls_loss *= rcnn_cls_weights
            rcnn_cls_loss = rcnn_cls_loss.sum(dim=-1)

            # bounding box regression L1 loss
            rcnn_bbox_preds = prediction_dict['rcnn_bbox_preds_' + stage_idx]
            rcnn_bbox_loss = self.rcnn_bbox_loss(rcnn_bbox_preds,
                                                 rcnn_reg_targets).sum(dim=-1)
            rcnn_bbox_loss *= rcnn_reg_weights
            # rcnn_bbox_loss *= rcnn_reg_weights
            rcnn_bbox_loss = rcnn_bbox_loss.sum(dim=-1)

            # loss weights has no gradients
            loss_dict['rcnn_cls_loss_' + stage_idx] = rcnn_cls_loss
            loss_dict['rcnn_bbox_loss_' + stage_idx] = rcnn_bbox_loss

            # add rcnn_cls_targets to get the statics of rpn
            loss_dict['rcnn_cls_targets_'] = rcnn_cls_targets

        return loss_dict
예제 #17
0
class SSDModel(Model):
    def init_param(self, model_config):
        """Copy hyper-parameters from ``model_config`` onto the model.

        The multibox configuration is fixed to three entries per source for
        six sources rather than read from the config.  The sampler, target
        assigner and anchor generator are built from their sub-configs.
        """
        self.feature_extractor_config = model_config[
            'feature_extractor_config']
        # hard-coded; the config's own multibox entry is not consulted
        self.multibox_cfg = [3] * 6
        class_names = model_config['classes']
        self.n_classes = len(class_names)
        self.sampler = DetectionSampler(model_config['sampler_config'])
        for key in ('batch_size', 'use_focal_loss'):
            setattr(self, key, model_config[key])

        assigner_config = model_config['target_assigner_config']
        self.target_assigner = TargetAssigner(assigner_config)

        generator_config = model_config['anchor_generator_config']
        self.anchor_generator = AnchorGenerator(generator_config)

        # the coder is owned by the assigner; keep a shortcut to it
        self.bbox_coder = self.target_assigner.bbox_coder

    def init_modules(self):
        """Build the feature extractor, the multibox heads and the losses.

        Both losses are created unreduced; the confidence loss is a
        ``FocalLoss`` when ``use_focal_loss`` is set, otherwise plain
        cross entropy.
        """
        extractor = PyramidVggnetExtractor(self.feature_extractor_config)
        self.feature_extractor = extractor

        # localisation and confidence heads, one pair per source layer
        heads = self.make_multibox(self.feature_extractor.base_feat,
                                   self.feature_extractor.extras_layers)
        self.loc_layers, self.conf_layers = heads

        # loss layers
        self.loc_loss = nn.SmoothL1Loss(reduce=False)
        if not self.use_focal_loss:
            self.conf_loss = nn.CrossEntropyLoss(reduce=False)
        else:
            self.conf_loss = FocalLoss(
                self.n_classes, alpha=0.2, gamma=2, auto_alpha=False)

    def make_multibox(self, vgg, extra_layers):
        cfg = self.multibox_cfg
        num_classes = self.n_classes
        loc_layers = []
        conf_layers = []
        vgg_source = [21, -2]
        for k, v in enumerate(vgg_source):
            loc_layers += [
                nn.Conv2d(
                    vgg[v].out_channels, cfg[k] * 4, kernel_size=3, padding=1)
            ]
            conf_layers += [
                nn.Conv2d(
                    vgg[v].out_channels,
                    cfg[k] * num_classes,
                    kernel_size=3,
                    padding=1)
            ]
        for k, v in enumerate(extra_layers[1::2], 2):
            loc_layers += [
                nn.Conv2d(
                    v.out_channels, cfg[k] * 4, kernel_size=3, padding=1)
            ]
            conf_layers += [
                nn.Conv2d(
                    v.out_channels,
                    cfg[k] * num_classes,
                    kernel_size=3,
                    padding=1)
            ]
        return nn.ModuleList(loc_layers), nn.ModuleList(conf_layers)

    def init_weights(self):
        pass

    def forward(self, feed_dict):
        img = feed_dict['img']
        source_feats = self.feature_extractor(img)
        loc_preds = []
        conf_preds = []

        featmap_shapes = []

        # apply multibox head to source layers
        for (x, l, c) in zip(source_feats, self.loc_layers, self.conf_layers):
            loc_preds.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf_preds.append(c(x).permute(0, 2, 3, 1).contiguous())
            featmap_shapes.append(x.size()[-2:])

        # import ipdb
        # ipdb.set_trace()
        loc_preds = torch.cat([o.view(o.size(0), -1) for o in loc_preds], 1)
        conf_preds = torch.cat([o.view(o.size(0), -1) for o in conf_preds], 1)
        probs = F.softmax(
            conf_preds.view(conf_preds.size(0), -1, self.n_classes), dim=-1)
        loc_preds = loc_preds.view(loc_preds.size(0), -1, 4)

        # import ipdb
        # ipdb.set_trace()
        anchors = self.anchor_generator.generate_pyramid(featmap_shapes)
        # anchors = self.priorsbox.forward(featmap_shapes)

        # import ipdb
        # ipdb.set_trace()
        rois_batch_inds = torch.zeros_like(loc_preds[:, :, -1:])
        rois_batch = torch.cat([rois_batch_inds, anchors.unsqueeze(0)], dim=-1)
        second_rpn_anchors = anchors.unsqueeze(0)

        rcnn_3d = torch.zeros_like(loc_preds)

        prediction_dict = {
            'rcnn_bbox_preds': loc_preds,
            'rcnn_cls_scores': conf_preds,
            'anchors': anchors,
            'rcnn_cls_probs': probs,
            'rois_batch': rois_batch,
            'second_rpn_anchors': second_rpn_anchors,
            'rcnn_3d': rcnn_3d
        }
        return prediction_dict

    def loss(self, prediction_dict, feed_dict):
        # import ipdb
        # ipdb.set_trace()
        # loss for cls
        loss_dict = {}

        gt_boxes = feed_dict['gt_boxes']

        anchors = prediction_dict['anchors']

        #################################
        # target assigner
        ################################
        # no need gt labels here,it just a binary classifcation problem
        # import ipdb
        # ipdb.set_trace()
        rpn_cls_targets, rpn_reg_targets, \
            rpn_cls_weights, rpn_reg_weights = \
            self.target_assigner.assign(anchors, gt_boxes, gt_labels=None)

        ################################
        # subsample
        ################################

        pos_indicator = rpn_reg_weights > 0
        indicator = rpn_cls_weights > 0

        rpn_cls_probs = prediction_dict['rcnn_cls_probs'][:, :, 1]
        cls_criterion = rpn_cls_probs

        batch_sampled_mask = self.sampler.subsample_batch(
            self.batch_size,
            pos_indicator,
            criterion=cls_criterion,
            indicator=indicator)
        batch_sampled_mask = batch_sampled_mask.type_as(rpn_cls_weights)
        rpn_cls_weights = rpn_cls_weights * batch_sampled_mask
        rpn_reg_weights = rpn_reg_weights * batch_sampled_mask
        num_cls_coeff = (rpn_cls_weights > 0).sum(dim=1)
        num_reg_coeff = (rpn_reg_weights > 0).sum(dim=1)
        # check
        #  assert num_cls_coeff, 'bug happens'
        #  assert num_reg_coeff, 'bug happens'
        if num_cls_coeff == 0:
            num_cls_coeff = torch.ones([]).type_as(num_cls_coeff)
        if num_reg_coeff == 0:
            num_reg_coeff = torch.ones([]).type_as(num_reg_coeff)

        # cls loss
        rpn_cls_score = prediction_dict['rcnn_cls_scores']
        # rpn_cls_loss = self.rpn_cls_loss(rpn_cls_score, rpn_cls_targets)
        rpn_cls_loss = self.conf_loss(
            rpn_cls_score.view(-1, 2), rpn_cls_targets.view(-1))
        rpn_cls_loss = rpn_cls_loss.view_as(rpn_cls_weights)
        rpn_cls_loss *= rpn_cls_weights
        rpn_cls_loss = rpn_cls_loss.sum(dim=1) / num_cls_coeff.float()

        # bbox loss
        # shape(N,num,4)
        rpn_bbox_preds = prediction_dict['rcnn_bbox_preds']
        # rpn_bbox_preds = rpn_bbox_preds.permute(0, 2, 3, 1).contiguous()
        # shape(N,H*W*num_anchors,4)
        # rpn_bbox_preds = rpn_bbox_preds.view(rpn_bbox_preds.shape[0], -1, 4)
        # import ipdb
        # ipdb.set_trace()
        rpn_reg_loss = self.loc_loss(rpn_bbox_preds, rpn_reg_targets)
        rpn_reg_loss *= rpn_reg_weights.unsqueeze(-1).expand(-1, -1, 4)
        rpn_reg_loss = rpn_reg_loss.view(rpn_reg_loss.shape[0], -1).sum(
            dim=1) / num_reg_coeff.float()

        prediction_dict['rcnn_reg_weights'] = rpn_reg_weights[
            batch_sampled_mask > 0]

        loss_dict['rpn_cls_loss'] = rpn_cls_loss
        loss_dict['rpn_bbox_loss'] = rpn_reg_loss

        # recall
        final_boxes = self.bbox_coder.decode_batch(rpn_bbox_preds, anchors)
        self.target_assigner.assign(final_boxes, gt_boxes)
        return loss_dict
예제 #18
0
class RPNModel(Model):
    def init_param(self, model_config):
        """Copy hyper-parameters from ``model_config`` onto the model.

        Every key is required except ``'use_iou'``, which is read with
        ``dict.get``.  The sampler, anchor generator and target assigner
        are built from their sub-configs, and the output channel counts of
        the two prediction layers are derived from the anchor count.
        """
        self.in_channels = model_config['din']
        for key in ('post_nms_topN', 'pre_nms_topN', 'nms_thresh',
                    'use_score', 'rpn_batch_size', 'use_focal_loss'):
            setattr(self, key, model_config[key])

        # sampler
        self.sampler = DetectionSampler(model_config['sampler_config'])

        # anchor generator and the channel counts that follow from it
        self.anchor_generator = AnchorGenerator(
            model_config['anchor_generator_config'])
        anchor_count = self.anchor_generator.num_anchors
        self.num_anchors = anchor_count
        self.nc_bbox_out = 4 * anchor_count
        self.nc_score_out = anchor_count * 2

        # target assigner and the bbox coder it owns
        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])
        self.bbox_coder = self.target_assigner.bbox_coder

        self.use_iou = model_config.get('use_iou')

    def init_weights(self):
        """Fill the three RPN conv layers via ``Filler.normal_init``.

        ``self.truncated`` is set to ``False`` here and passed along; every
        layer uses mean 0 and standard deviation 0.01.
        """
        self.truncated = False

        for layer_name in ('rpn_conv', 'rpn_cls_score', 'rpn_bbox_pred'):
            Filler.normal_init(
                getattr(self, layer_name), 0, 0.01, self.truncated)

    def init_modules(self):
        # define the convrelu layers processing input feature map
        self.rpn_conv = nn.Conv2d(self.in_channels, 512, 3, 1, 1, bias=True)

        # define bg/fg classifcation score layer
        self.rpn_cls_score = nn.Conv2d(512, self.nc_score_out, 1, 1, 0)

        # define anchor box offset prediction layer

        if self.use_score:
            bbox_feat_channels = 512 + 2
            self.nc_bbox_out /= self.num_anchors
        else:
            bbox_feat_channels = 512
        self.rpn_bbox_pred = nn.Conv2d(bbox_feat_channels, self.nc_bbox_out, 1,
                                       1, 0)

        # bbox
        self.rpn_bbox_loss = nn.modules.loss.SmoothL1Loss(reduce=False)

        # cls
        if self.use_focal_loss:
            self.rpn_cls_loss = FocalLoss(2)
        else:
            self.rpn_cls_loss = functools.partial(F.cross_entropy,
                                                  reduce=False)

    # def generate_proposal(self, rpn_cls_probs, anchors, rpn_bbox_preds,
    # im_info):
    # pass

    def forward(self, bottom_blobs):
        """
        Score and regress anchors on the base feature map and turn them
        into proposals.

        Args:
            bottom_blobs: mapping that must provide ``'base_feat'``,
                ``'gt_boxes'`` and ``'im_info'``.

        Returns:
            dict with the keys ``rpn_cls_scores`` and ``rpn_cls_probs``
            (both reshaped to ``(batch_size, -1, 2)``), ``rois_batch``,
            ``anchors``, ``rpn_bbox_preds`` and ``proposals_order``.
        """
        base_feat = bottom_blobs['base_feat']
        batch_size = base_feat.shape[0]
        gt_boxes = bottom_blobs['gt_boxes']
        im_info = bottom_blobs['im_info']

        # rpn conv
        rpn_conv = F.relu(self.rpn_conv(base_feat), inplace=True)

        # rpn cls score
        # shape(N,2*num_anchors,H,W)
        rpn_cls_scores = self.rpn_cls_score(rpn_conv)

        # rpn cls prob shape(N,2*num_anchors,H,W)
        # the scores are viewed as two groups and softmaxed across them
        rpn_cls_score_reshape = rpn_cls_scores.view(batch_size, 2, -1)
        rpn_cls_probs = F.softmax(rpn_cls_score_reshape, dim=1)
        rpn_cls_probs = rpn_cls_probs.view_as(rpn_cls_scores)

        # rpn bbox pred
        # shape(N,4*num_anchors,H,W)
        if self.use_score:
            # NOTE(review): this branch looks broken -- the permuted tensor
            # has three axes but is indexed with four below, the loop
            # variable `i` is never used, and the 5-axis view further down
            # reads shape[2]/shape[3] of the rebound tensor.  Confirm
            # whether use_score is ever enabled.
            rpn_cls_scores = rpn_cls_score_reshape.permute(0, 2, 1)
            rpn_bbox_preds = []
            for i in range(self.num_anchors):
                rpn_bbox_feat = torch.cat(
                    [rpn_conv, rpn_cls_scores[:, ::self.num_anchors, :, :]],
                    dim=1)
                rpn_bbox_preds.append(self.rpn_bbox_pred(rpn_bbox_feat))
            rpn_bbox_preds = torch.cat(rpn_bbox_preds, dim=1)
        else:
            # get rpn offsets to the anchor boxes
            rpn_bbox_preds = self.rpn_bbox_pred(rpn_conv)

        # generate anchors from the spatial size of the feature map
        feature_map_list = [base_feat.size()[-2:]]
        anchors = self.anchor_generator.generate(feature_map_list)

        ###############################
        # Proposal
        ###############################
        # note that proposals_order is used to track the transform of
        # proposals
        rois_batch, proposals_order = Proposal.apply(rpn_cls_probs, anchors,
                                                     rpn_bbox_preds, im_info)

        # in training mode the ground-truth boxes are added as extra rois
        if self.training:
            rois_batch = self.append_gt(rois_batch, gt_boxes)

        # split the channel axis into (2, -1), move it behind the spatial
        # axes and flatten to one score pair per row
        rpn_cls_scores = rpn_cls_scores.view(batch_size, 2, -1,
                                             rpn_cls_scores.shape[2],
                                             rpn_cls_scores.shape[3])
        rpn_cls_scores = rpn_cls_scores.permute(0, 3, 4, 2,
                                                1).contiguous().view(
                                                    batch_size, -1, 2)

        # postprocess: same reshaping for the probabilities
        rpn_cls_probs = rpn_cls_probs.view(batch_size, 2, -1,
                                           rpn_cls_probs.shape[2],
                                           rpn_cls_probs.shape[3])
        rpn_cls_probs = rpn_cls_probs.permute(0, 3, 4, 2, 1).contiguous().view(
            batch_size, -1, 2)
        predict_dict = {
            'rpn_cls_scores': rpn_cls_scores,
            'rois_batch': rois_batch,
            'anchors': anchors,

            # used for loss
            'rpn_bbox_preds': rpn_bbox_preds,
            'rpn_cls_probs': rpn_cls_probs,
            'proposals_order': proposals_order,
        }

        return predict_dict

    def append_gt(self, rois_batch, gt_boxes):
        ################################
        # append gt_boxes to rois_batch for losses
        ################################
        # may be some bugs here
        gt_boxes_append = torch.zeros(gt_boxes.shape[0], gt_boxes.shape[1],
                                      5).type_as(gt_boxes)
        gt_boxes_append[:, :, 1:5] = gt_boxes[:, :, :4]
        # cat gt_boxes to rois_batch
        rois_batch = torch.cat([rois_batch, gt_boxes_append], dim=1)
        return rois_batch

    def loss(self, prediction_dict, feed_dict):
        # loss for cls
        loss_dict = {}

        gt_boxes = feed_dict['gt_boxes']

        anchors = prediction_dict['anchors']

        assert len(anchors) == 1, 'just one feature maps is supported now'
        anchors = anchors[0]

        #################################
        # target assigner
        ################################
        # no need gt labels here,it just a binary classifcation problem
        #  import ipdb
        #  ipdb.set_trace()
        rpn_cls_targets, rpn_reg_targets, \
            rpn_cls_weights, rpn_reg_weights = \
            self.target_assigner.assign(anchors, gt_boxes, gt_labels=None)

        ################################
        # subsample
        ################################

        pos_indicator = rpn_reg_weights > 0
        indicator = rpn_cls_weights > 0

        if self.use_iou:
            cls_criterion = self.target_assigner.matcher.assigned_overlaps_batch
        else:
            rpn_cls_probs = prediction_dict['rpn_cls_probs'][:, :, 1]
            cls_criterion = rpn_cls_probs

        batch_sampled_mask = self.sampler.subsample_batch(
            self.rpn_batch_size,
            pos_indicator,
            criterion=cls_criterion,
            indicator=indicator)
        batch_sampled_mask = batch_sampled_mask.type_as(rpn_cls_weights)
        rpn_cls_weights = rpn_cls_weights * batch_sampled_mask
        rpn_reg_weights = rpn_reg_weights * batch_sampled_mask
        num_cls_coeff = (rpn_cls_weights > 0).sum(dim=1)
        num_reg_coeff = (rpn_reg_weights > 0).sum(dim=1)
        # check
        #  assert num_cls_coeff, 'bug happens'
        #  assert num_reg_coeff, 'bug happens'
        if num_cls_coeff == 0:
            num_cls_coeff = torch.ones([]).type_as(num_cls_coeff)
        if num_reg_coeff == 0:
            num_reg_coeff = torch.ones([]).type_as(num_reg_coeff)

        # cls loss
        rpn_cls_score = prediction_dict['rpn_cls_scores']
        # rpn_cls_loss = self.rpn_cls_loss(rpn_cls_score, rpn_cls_targets)
        rpn_cls_loss = self.rpn_cls_loss(rpn_cls_score.view(-1, 2),
                                         rpn_cls_targets.view(-1))
        rpn_cls_loss = rpn_cls_loss.view_as(rpn_cls_weights)
        rpn_cls_loss *= rpn_cls_weights
        rpn_cls_loss = rpn_cls_loss.sum(dim=1) / num_cls_coeff.float()

        # bbox loss
        # shape(N,num,4)
        rpn_bbox_preds = prediction_dict['rpn_bbox_preds']
        rpn_bbox_preds = rpn_bbox_preds.permute(0, 2, 3, 1).contiguous()
        # shape(N,H*W*num_anchors,4)
        rpn_bbox_preds = rpn_bbox_preds.view(rpn_bbox_preds.shape[0], -1, 4)
        rpn_reg_loss = self.rpn_bbox_loss(rpn_bbox_preds, rpn_reg_targets)
        rpn_reg_loss *= rpn_reg_weights.unsqueeze(-1).expand(-1, -1, 4)
        rpn_reg_loss = rpn_reg_loss.view(
            rpn_reg_loss.shape[0], -1).sum(dim=1) / num_reg_coeff.float()

        loss_dict['rpn_cls_loss'] = rpn_cls_loss
        loss_dict['rpn_bbox_loss'] = rpn_reg_loss
        return loss_dict
예제 #19
0
class NewSemanticFasterRCNN(Model):
    """Two-stage detector whose box branch is gated by the class head.

    The classification head emits ``n_classes * 2048`` values per RoI. They
    are regrouped into two groups, averaged inside each group to obtain the
    class scores, and the second group additionally multiplies the pooled
    feature map before box regression.
    """

    def forward(self, feed_dict):
        """Run the backbone, the RPN and the attention-gated RCNN heads.

        Args:
            feed_dict: mapping that holds ``'img'``; ``'base_feat'`` is
                written into it. In training mode it is also handed to
                ``pre_subsample``, which reads the ground truth from it.

        Returns:
            dict: everything returned by the RPN plus ``'rcnn_cls_probs'``,
            ``'rcnn_bbox_preds'``, ``'rcnn_cls_scores'`` and
            ``'second_rpn_anchors'``.
        """
        prediction_dict = {}

        # first-stage features are shared with the RPN through feed_dict
        base_feat = self.feature_extractor.first_stage_feature(
            feed_dict['img'])
        feed_dict.update({'base_feat': base_feat})

        prediction_dict.update(self.rpn_model.forward(feed_dict))

        # shrink the proposal set before pooling to reduce memory use
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict)
        flat_rois = prediction_dict['rois_batch'].view(-1, 5)

        pooled_feat = self.rcnn_pooling(base_feat, flat_rois)
        pooled_feat = self.feature_extractor.second_stage_feature(pooled_feat)

        # one descriptor per RoI: average over the two trailing axes
        rcnn_cls_feat = pooled_feat.mean(3).mean(2)
        num_rois = rcnn_cls_feat.shape[0]

        # NOTE(review): the literal 2 presumably stands for n_classes == 2
        # (background / foreground) -- confirm before using other class counts
        attention = self.rcnn_cls_pred(rcnn_cls_feat)
        rcnn_cls_scores = attention.view(num_rois, 2, -1).mean(dim=-1)
        rcnn_cls_probs = F.softmax(rcnn_cls_scores, dim=1)

        # trailing singleton axes let the gate broadcast over the pooled grid
        attention = attention.view(num_rois, 2, -1, 1, 1)
        rcnn_bbox_feat = pooled_feat * attention[:, 1]

        # per-location box output, reduced by max over axis 3 and then axis 2
        bbox_map = self.rcnn_bbox_pred(rcnn_bbox_feat)
        rcnn_bbox_preds = bbox_map.max(3)[0].max(2)[0]

        prediction_dict.update({
            'rcnn_cls_probs': rcnn_cls_probs,
            'rcnn_bbox_preds': rcnn_bbox_preds,
            'rcnn_cls_scores': rcnn_cls_scores,
        })

        # used for track
        order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = \
            prediction_dict['anchors'][0][order]

        return prediction_dict

    def init_weights(self):
        """Initialise the sub-models first, then both RCNN heads."""
        self.feature_extractor.init_weights()
        self.rpn_model.init_weights()

        # (head, std) pairs; zero mean for both
        for head, std in ((self.rcnn_cls_pred, 0.01),
                          (self.rcnn_bbox_pred, 0.001)):
            Filler.normal_init(head, 0, std, self.truncated)

    def init_modules(self):
        """Create backbone, RPN, RoI pooling, prediction heads and losses.

        Raises:
            NotImplementedError: for the 'psalign' and 'deformable_psalign'
                pooling modes.
        """
        self.feature_extractor = ResNetFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)

        # any other pooling mode leaves self.rcnn_pooling undefined
        mode = self.pooling_mode
        if mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(self.pooling_size,
                                            self.pooling_size, 1.0 / 16.0)
        elif mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)
        elif mode in ('psalign', 'deformable_psalign'):
            raise NotImplementedError('have not implemented yet!')

        # one output per (class, channel) pair, see forward()
        self.rcnn_cls_pred = nn.Linear(2048, self.n_classes * 2048)

        if self.class_agnostic:
            self.rcnn_bbox_pred = nn.Conv2d(2048, 4, 3, 1, 1)
        else:
            # NOTE(review): forward() feeds this head a 4-D feature map and
            # reduces axes 3 and 2 of its output; a Linear layer looks
            # incompatible with that -- confirm this branch is ever used
            self.rcnn_bbox_pred = nn.Linear(2048, 4 * self.n_classes)

        # both losses are left unreduced; loss() applies the weights itself
        if self.use_focal_loss:
            self.rcnn_cls_loss = FocalLoss(2)
        else:
            self.rcnn_cls_loss = functools.partial(
                F.cross_entropy, reduce=False)
        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

    def init_param(self, model_config):
        """Copy hyper-parameters from ``model_config`` and build helpers.

        Args:
            model_config: mapping with the keys read below.
        """
        self.classes = model_config['classes']
        self.n_classes = len(self.classes)

        # options stored under the same name as their config key
        for key in ('class_agnostic', 'pooling_size', 'pooling_mode',
                    'crop_resize_with_max_pool', 'truncated',
                    'use_focal_loss', 'subsample_twice', 'rcnn_batch_size',
                    'feature_extractor_config', 'rpn_config'):
            setattr(self, key, model_config[key])

        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])
        self.sampler = BalancedSampler(model_config['sampler_config'])

    def pre_subsample(self, prediction_dict, feed_dict):
        """Assign targets to the proposals and keep a sampled subset.

        Writes the sampled, count-normalised weights, the sampled targets and
        the reduced ``'rois_batch'`` into ``prediction_dict``.

        Args:
            prediction_dict: must hold ``'rois_batch'``; updated in place.
            feed_dict: must hold ``'gt_boxes'`` and ``'gt_labels'``.
        """
        rois_batch = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        # column 0 of every roi is dropped before matching
        (cls_targets, reg_targets,
         cls_weights, reg_weights) = self.target_assigner.assign(
             rois_batch[:, :, 1:], gt_boxes, gt_labels)

        # positives: non-zero regression weight; candidates: non-zero
        # classification weight; no ranking criterion is supplied
        keep = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            reg_weights > 0,
            indicator=cls_weights > 0,
            criterion=None)

        cls_weights = cls_weights[keep]
        reg_weights = reg_weights[keep]
        num_cls_coeff = (cls_weights > 0).sum(dim=-1)
        num_reg_coeff = (reg_weights > 0).sum(dim=-1)
        # both counts are used as divisors right below
        assert num_cls_coeff, 'bug happens'
        assert num_reg_coeff, 'bug happens'

        prediction_dict['rcnn_cls_weights'] = \
            cls_weights / num_cls_coeff.float()
        prediction_dict['rcnn_reg_weights'] = \
            reg_weights / num_reg_coeff.float()
        prediction_dict['rcnn_cls_targets'] = cls_targets[keep]
        prediction_dict['rcnn_reg_targets'] = reg_targets[keep]

        # restore the leading batch axis on the kept rois
        prediction_dict['rois_batch'] = rois_batch[keep].view(
            rois_batch.shape[0], -1, 5)

        if not self.training:
            # used for track
            order = prediction_dict['proposals_order']
            prediction_dict['proposals_order'] = order[keep]

    def loss(self, prediction_dict, feed_dict):
        """Combine the RPN losses with the weighted RCNN losses.

        Targets and weights are the ones stored by ``pre_subsample``.

        Returns:
            dict: RPN loss entries plus ``'rcnn_cls_loss'``,
            ``'rcnn_bbox_loss'`` and ``'rcnn_cls_targets'``.
        """
        loss_dict = {}
        loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

        cls_weights = prediction_dict['rcnn_cls_weights']
        reg_weights = prediction_dict['rcnn_reg_weights']
        cls_targets = prediction_dict['rcnn_cls_targets']
        reg_targets = prediction_dict['rcnn_reg_targets']

        # classification: per-sample loss, weighted, then summed
        cls_loss = self.rcnn_cls_loss(prediction_dict['rcnn_cls_scores'],
                                      cls_targets)
        cls_loss *= cls_weights
        cls_loss = cls_loss.sum(dim=-1)

        # regression: sum over the last axis, weight, then sum again
        bbox_loss = self.rcnn_bbox_loss(prediction_dict['rcnn_bbox_preds'],
                                        reg_targets).sum(dim=-1)
        bbox_loss *= reg_weights
        bbox_loss = bbox_loss.sum(dim=-1)

        loss_dict['rcnn_cls_loss'] = cls_loss
        loss_dict['rcnn_bbox_loss'] = bbox_loss
        # exported alongside the losses so callers can compute statistics
        loss_dict['rcnn_cls_targets'] = cls_targets

        return loss_dict
예제 #20
0
class GateFasterRCNN(Model):
    """Faster R-CNN variant built on ``GateRPNModel``.

    The second stage is a plain pair of linear heads on globally averaged
    RoI features; proposals are subsampled with a ``DetectionSampler`` that
    is given the matcher's assigned overlaps as criterion.
    """

    def forward(self, feed_dict):
        """Run the backbone, the RPN and the RCNN heads.

        Args:
            feed_dict: mapping that holds ``'img'``; ``'base_feat'`` is
                written into it. In training mode it is also handed to
                ``pre_subsample``, which reads the ground truth from it.

        Returns:
            dict: everything returned by the RPN plus ``'rcnn_cls_probs'``,
            ``'rcnn_bbox_preds'``, ``'rcnn_cls_scores'`` and
            ``'second_rpn_anchors'``.
        """
        prediction_dict = {}

        # base model
        base_feat = self.feature_extractor.first_stage_feature(
            feed_dict['img'])
        feed_dict.update({'base_feat': base_feat})

        # rpn model
        prediction_dict.update(self.rpn_model.forward(feed_dict))

        # pre subsample to reduce memory consumption
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict)
        rois_batch = prediction_dict['rois_batch']

        # rois are flattened to rows of 5 values before pooling
        pooled_feat = self.rcnn_pooling(base_feat, rois_batch.view(-1, 5))

        pooled_feat = self.feature_extractor.second_stage_feature(pooled_feat)
        # average over the two trailing axes -> one vector per roi
        pooled_feat = pooled_feat.mean(3).mean(2)

        rcnn_bbox_preds = self.rcnn_bbox_pred(pooled_feat)
        rcnn_cls_scores = self.rcnn_cls_pred(pooled_feat)

        rcnn_cls_probs = F.softmax(rcnn_cls_scores, dim=1)

        prediction_dict['rcnn_cls_probs'] = rcnn_cls_probs
        prediction_dict['rcnn_bbox_preds'] = rcnn_bbox_preds
        prediction_dict['rcnn_cls_scores'] = rcnn_cls_scores

        # used for track
        proposals_order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = prediction_dict['anchors'][0][
            proposals_order]

        return prediction_dict

    def init_weights(self):
        """Initialise the sub-models first, then both RCNN heads."""
        # submodule init weights
        self.feature_extractor.init_weights()
        self.rpn_model.init_weights()

        Filler.normal_init(self.rcnn_cls_pred, 0, 0.01, self.truncated)
        Filler.normal_init(self.rcnn_bbox_pred, 0, 0.001, self.truncated)

    def init_modules(self):
        """Create backbone, RPN, RoI pooling, prediction heads and losses."""
        self.feature_extractor = FeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = GateRPNModel(self.rpn_config)
        # pooling is always RoIAlignAvg here; self.pooling_mode is not read
        self.rcnn_pooling = RoIAlignAvg(self.pooling_size, self.pooling_size,
                                        1.0 / 16.0)
        self.rcnn_cls_pred = nn.Linear(2048, self.n_classes)
        if self.class_agnostic:
            self.rcnn_bbox_pred = nn.Linear(2048, 4)
        else:
            self.rcnn_bbox_pred = nn.Linear(2048, 4 * self.n_classes)

        # loss modules; both are left unreduced, loss() applies the weights
        if self.use_focal_loss:
            self.rcnn_cls_loss = FocalLoss(2)
        else:
            self.rcnn_cls_loss = functools.partial(F.cross_entropy,
                                                   reduce=False)

        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

    def init_param(self, model_config):
        """Copy hyper-parameters from ``model_config`` and build helpers.

        Args:
            model_config: mapping with the keys read below.
        """
        classes = model_config['classes']
        self.classes = classes
        self.n_classes = len(classes)
        self.class_agnostic = model_config['class_agnostic']
        self.pooling_size = model_config['pooling_size']
        self.pooling_mode = model_config['pooling_mode']
        self.crop_resize_with_max_pool = model_config[
            'crop_resize_with_max_pool']
        self.truncated = model_config['truncated']

        self.use_focal_loss = model_config['use_focal_loss']
        self.subsample_twice = model_config['subsample_twice']
        self.rcnn_batch_size = model_config['rcnn_batch_size']

        # some submodule config
        self.feature_extractor_config = model_config[
            'feature_extractor_config']
        self.rpn_config = model_config['rpn_config']

        # assigner
        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])

        # sampler
        self.sampler = DetectionSampler(model_config['sampler_config'])

    def pre_subsample(self, prediction_dict, feed_dict):
        """Assign targets to the proposals and keep a sampled subset.

        Writes the sampled, count-normalised weights, the sampled targets and
        the reduced ``'rois_batch'`` into ``prediction_dict``.

        Args:
            prediction_dict: must hold ``'rois_batch'``; updated in place.
            feed_dict: must hold ``'gt_boxes'`` and ``'gt_labels'``.
        """
        rois_batch = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        ##########################
        # assigner
        ##########################
        # column 0 of every roi is dropped before matching
        (rcnn_cls_targets, rcnn_reg_targets, rcnn_cls_weights,
         rcnn_reg_weights) = self.target_assigner.assign(
             rois_batch[:, :, 1:], gt_boxes, gt_labels)

        ##########################
        # subsampler
        ##########################
        # positives are taken from the class targets here, not from the
        # regression weights
        pos_indicator = rcnn_cls_targets > 0
        indicator = rcnn_cls_weights > 0

        # rank candidates by the overlaps stored on the matcher during
        # assign(); set the flag to False to sample without a criterion
        use_iou_for_criteron = True
        if use_iou_for_criteron:
            cls_criterion = self.target_assigner.matcher.assigned_overlaps_batch
        else:
            cls_criterion = None
        batch_sampled_mask = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            pos_indicator,
            indicator=indicator,
            criterion=cls_criterion)
        rcnn_cls_weights = rcnn_cls_weights[batch_sampled_mask]
        rcnn_reg_weights = rcnn_reg_weights[batch_sampled_mask]

        # Count the entries with a positive weight. A comparison works on
        # whatever device the weights live on and also counts fractional
        # weights, unlike a cast to torch.cuda.ByteTensor, which requires
        # CUDA and truncates values below 1 to 0.
        num_cls_coeff = (rcnn_cls_weights > 0).sum(dim=-1)
        num_reg_coeff = (rcnn_reg_weights > 0).sum(dim=-1)
        # both counts are used as divisors right below
        assert num_cls_coeff, 'bug happens'
        assert num_reg_coeff, 'bug happens'

        prediction_dict[
            'rcnn_cls_weights'] = rcnn_cls_weights / num_cls_coeff.float()
        prediction_dict[
            'rcnn_reg_weights'] = rcnn_reg_weights / num_reg_coeff.float()
        prediction_dict['rcnn_cls_targets'] = rcnn_cls_targets[
            batch_sampled_mask]
        prediction_dict['rcnn_reg_targets'] = rcnn_reg_targets[
            batch_sampled_mask]

        # update rois_batch, restoring the leading batch axis
        prediction_dict['rois_batch'] = rois_batch[batch_sampled_mask].view(
            rois_batch.shape[0], -1, 5)

        # NOTE(review): forward() only calls this method in training mode,
        # so this branch looks unreachable from there -- confirm intent
        if not self.training:
            # used for track
            proposals_order = prediction_dict['proposals_order']

            prediction_dict['proposals_order'] = proposals_order[
                batch_sampled_mask]

    def loss(self, prediction_dict, feed_dict):
        """Combine the RPN losses with the weighted RCNN losses.

        Targets and weights are the ones stored by ``pre_subsample``.

        Returns:
            dict: RPN loss entries plus ``'rcnn_cls_loss'``,
            ``'rcnn_bbox_loss'`` and ``'rcnn_cls_targets'``.
        """
        loss_dict = {}

        # submodule loss
        loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

        # targets and weights
        rcnn_cls_weights = prediction_dict['rcnn_cls_weights']
        rcnn_reg_weights = prediction_dict['rcnn_reg_weights']

        rcnn_cls_targets = prediction_dict['rcnn_cls_targets']
        rcnn_reg_targets = prediction_dict['rcnn_reg_targets']

        # classification loss: per-sample loss, weighted, then summed
        rcnn_cls_scores = prediction_dict['rcnn_cls_scores']
        rcnn_cls_loss = self.rcnn_cls_loss(rcnn_cls_scores, rcnn_cls_targets)
        rcnn_cls_loss *= rcnn_cls_weights
        rcnn_cls_loss = rcnn_cls_loss.sum(dim=-1)

        # bounding box regression loss: sum over the last axis, weight,
        # then sum again
        rcnn_bbox_preds = prediction_dict['rcnn_bbox_preds']
        rcnn_bbox_loss = self.rcnn_bbox_loss(rcnn_bbox_preds,
                                             rcnn_reg_targets).sum(dim=-1)
        rcnn_bbox_loss *= rcnn_reg_weights
        rcnn_bbox_loss = rcnn_bbox_loss.sum(dim=-1)

        loss_dict['rcnn_cls_loss'] = rcnn_cls_loss
        loss_dict['rcnn_bbox_loss'] = rcnn_bbox_loss

        # exported alongside the losses so callers can compute statistics
        loss_dict['rcnn_cls_targets'] = rcnn_cls_targets

        return loss_dict
예제 #21
0
class FPNFasterRCNN(Model):
    """Faster R-CNN variant that pools RoI features from a feature pyramid.

    The feature extractor returns separate feature maps for the RPN and for
    the RCNN stage; every RoI is pooled from the map selected by
    ``calculate_roi_level``.
    """

    def calculate_roi_level(self, rois_batch):
        """Return the pyramid level assigned to each RoI.

        Args:
            rois_batch: 2-D tensor of RoI rows; columns 1-4 are read as the
                box extents (width from columns 3 and 1, height from columns
                4 and 2).

        Returns:
            Tensor with one level per row. A size-based level is computed
            and clamped to [2, 5], but the final statement overwrites every
            entry with 4.
        """
        box_h = rois_batch[:, 4] - rois_batch[:, 2] + 1
        box_w = rois_batch[:, 3] - rois_batch[:, 1] + 1

        roi_level = torch.round(
            torch.log(torch.sqrt(box_w * box_h) / 224.0) + 4)
        roi_level.clamp_(min=2, max=5)

        # NOTE(review): every roi is forced onto level 4. init_modules builds
        # a single pooler with a fixed 1/16 scale, which presumably is why --
        # confirm before removing this override.
        roi_level.fill_(4)
        return roi_level

    def pyramid_rcnn_pooling(self, rcnn_feat_maps, rois_batch):
        """Pool each RoI from the feature map of its assigned level.

        Args:
            rcnn_feat_maps: sequence of feature maps; the i-th entry is
                treated as pyramid level ``i + 2``.
            rois_batch: 2-D tensor of RoI rows.

        Returns:
            Pooled features of all non-empty levels concatenated along
            dim 0, in level order.
        """
        roi_level = self.calculate_roi_level(rois_batch)

        pooled_feats = []
        for level, feat_map in enumerate(rcnn_feat_maps, start=2):
            level_rois = rois_batch[roi_level == level]
            # levels that received no roi contribute nothing
            if level_rois.shape[0] != 0:
                pooled_feats.append(self.rcnn_pooling(feat_map, level_rois))
        return torch.cat(pooled_feats, dim=0)

    def forward(self, feed_dict):
        """Run the backbone, the RPN and the RCNN heads.

        Args:
            feed_dict: mapping that holds ``'img'``; ``'rpn_feat_maps'`` is
                written into it. In training mode it is also handed to
                ``pre_subsample``, which reads the ground truth from it.

        Returns:
            dict: everything returned by the RPN plus ``'rcnn_cls_probs'``,
            ``'rcnn_bbox_preds'``, ``'rcnn_cls_scores'`` and
            ``'second_rpn_anchors'``.
        """
        prediction_dict = {}

        # the extractor yields one set of maps per stage
        rpn_feat_maps, rcnn_feat_maps = \
            self.feature_extractor.first_stage_feature(feed_dict['img'])
        feed_dict.update({'rpn_feat_maps': rpn_feat_maps})

        prediction_dict.update(self.rpn_model.forward(feed_dict))

        # shrink the proposal set before pooling to reduce memory use
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict)
        flat_rois = prediction_dict['rois_batch'].view(-1, 5)

        pooled_feat = self.pyramid_rcnn_pooling(rcnn_feat_maps, flat_rois)
        pooled_feat = self.feature_extractor.second_stage_feature(pooled_feat)

        if self.reduce:
            # average over the two trailing axes
            pooled_feat = pooled_feat.mean(3).mean(2)
        else:
            # flatten everything behind a leading axis of rcnn_batch_size
            pooled_feat = pooled_feat.view(self.rcnn_batch_size, -1)

        rcnn_bbox_preds = self.rcnn_bbox_pred(pooled_feat)
        rcnn_cls_scores = self.rcnn_cls_pred(pooled_feat)
        rcnn_cls_probs = F.softmax(rcnn_cls_scores, dim=1)

        prediction_dict.update({
            'rcnn_cls_probs': rcnn_cls_probs,
            'rcnn_bbox_preds': rcnn_bbox_preds,
            'rcnn_cls_scores': rcnn_cls_scores,
        })

        # used for track; the anchors are indexed directly, with no [0]
        order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = \
            prediction_dict['anchors'][order]

        return prediction_dict

    def init_weights(self):
        """Initialise the sub-models first, then both RCNN heads."""
        self.feature_extractor.init_weights()
        self.rpn_model.init_weights()

        # (head, std) pairs; zero mean for both
        for head, std in ((self.rcnn_cls_pred, 0.01),
                          (self.rcnn_bbox_pred, 0.001)):
            Filler.normal_init(head, 0, std, self.truncated)

    def init_modules(self):
        """Create backbone, RPN, RoI pooling, prediction heads and losses.

        Raises:
            NotImplementedError: for the 'psalign' and 'deformable_psalign'
                pooling modes.
        """
        self.feature_extractor = FPNFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)

        # any other pooling mode leaves self.rcnn_pooling undefined
        mode = self.pooling_mode
        if mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(self.pooling_size,
                                            self.pooling_size, 1.0 / 16.0)
        elif mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)
        elif mode in ('psalign', 'deformable_psalign'):
            raise NotImplementedError('have not implemented yet!')

        # the class head always takes 1024 inputs, whatever self.reduce is
        self.rcnn_cls_pred = nn.Linear(1024, self.n_classes)

        in_channels = 1024 if self.reduce else 2048 * 4 * 4
        if self.class_agnostic:
            self.rcnn_bbox_pred = nn.Linear(in_channels, 4)
        else:
            self.rcnn_bbox_pred = nn.Linear(in_channels, 4 * self.n_classes)

        # both losses are left unreduced; loss() applies the weights itself
        if self.use_focal_loss:
            self.rcnn_cls_loss = FocalLoss(2)
        else:
            self.rcnn_cls_loss = functools.partial(F.cross_entropy,
                                                   reduce=False)
        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

    def init_param(self, model_config):
        """Copy hyper-parameters from ``model_config`` and build helpers.

        Args:
            model_config: mapping with the keys read below.
        """
        self.classes = model_config['classes']
        self.n_classes = len(self.classes)

        # options stored under the same name as their config key
        for key in ('class_agnostic', 'pooling_size', 'pooling_mode',
                    'crop_resize_with_max_pool', 'truncated',
                    'use_focal_loss', 'subsample_twice', 'rcnn_batch_size',
                    'feature_extractor_config', 'rpn_config'):
            setattr(self, key, model_config[key])

        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])
        self.sampler = BalancedSampler(model_config['sampler_config'])

        # fixed to True; not read from model_config
        self.reduce = True

    def pre_subsample(self, prediction_dict, feed_dict):
        """Assign targets to the proposals and keep a sampled subset.

        Writes the sampled, count-normalised weights, the sampled targets and
        the reduced ``'rois_batch'`` into ``prediction_dict``.

        Args:
            prediction_dict: must hold ``'rois_batch'``; updated in place.
            feed_dict: must hold ``'gt_boxes'`` and ``'gt_labels'``.
        """
        rois_batch = prediction_dict['rois_batch']
        gt_boxes = feed_dict['gt_boxes']
        gt_labels = feed_dict['gt_labels']

        # column 0 of every roi is dropped before matching
        (cls_targets, reg_targets,
         cls_weights, reg_weights) = self.target_assigner.assign(
             rois_batch[:, :, 1:], gt_boxes, gt_labels)

        # positives: non-zero regression weight; candidates: non-zero
        # classification weight; no ranking criterion is supplied
        keep = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            reg_weights > 0,
            indicator=cls_weights > 0,
            criterion=None)

        cls_weights = cls_weights[keep]
        reg_weights = reg_weights[keep]
        num_cls_coeff = (cls_weights > 0).sum(dim=-1)
        num_reg_coeff = (reg_weights > 0).sum(dim=-1)
        # both counts are used as divisors right below
        assert num_cls_coeff, 'bug happens'
        assert num_reg_coeff, 'bug happens'

        prediction_dict['rcnn_cls_weights'] = \
            cls_weights / num_cls_coeff.float()
        prediction_dict['rcnn_reg_weights'] = \
            reg_weights / num_reg_coeff.float()
        prediction_dict['rcnn_cls_targets'] = cls_targets[keep]
        prediction_dict['rcnn_reg_targets'] = reg_targets[keep]

        # restore the leading batch axis on the kept rois
        prediction_dict['rois_batch'] = rois_batch[keep].view(
            rois_batch.shape[0], -1, 5)

        if not self.training:
            # used for track
            order = prediction_dict['proposals_order']
            prediction_dict['proposals_order'] = order[keep]

    def loss(self, prediction_dict, feed_dict):
        """Combine the RPN losses with the weighted RCNN losses.

        Targets and weights are the ones stored by ``pre_subsample``.

        Returns:
            dict: RPN loss entries plus ``'rcnn_cls_loss'`` and
            ``'rcnn_bbox_loss'``.
        """
        loss_dict = {}
        loss_dict.update(self.rpn_model.loss(prediction_dict, feed_dict))

        cls_weights = prediction_dict['rcnn_cls_weights']
        reg_weights = prediction_dict['rcnn_reg_weights']
        cls_targets = prediction_dict['rcnn_cls_targets']
        reg_targets = prediction_dict['rcnn_reg_targets']

        # classification: per-sample loss, weighted, then summed
        cls_loss = self.rcnn_cls_loss(prediction_dict['rcnn_cls_scores'],
                                      cls_targets)
        cls_loss *= cls_weights
        cls_loss = cls_loss.sum(dim=-1)

        # regression: sum over the last axis, weight, then sum again
        bbox_loss = self.rcnn_bbox_loss(prediction_dict['rcnn_bbox_preds'],
                                        reg_targets).sum(dim=-1)
        bbox_loss *= reg_weights
        bbox_loss = bbox_loss.sum(dim=-1)

        # unlike the sibling models, the class targets are not exported
        loss_dict['rcnn_cls_loss'] = cls_loss
        loss_dict['rcnn_bbox_loss'] = bbox_loss

        return loss_dict
예제 #22
0
class LossFasterRCNN(Model):
    def forward(self, feed_dict):
        """Run the backbone, the RPN and the saliency-gated RCNN heads.

        Args:
            feed_dict: mapping that holds ``'img'``; ``'base_feat'`` is
                written into it. In training mode it is also handed to
                ``pre_subsample``, which reads the ground truth from it.

        Returns:
            dict: everything returned by the RPN plus ``'rcnn_cls_probs'``,
            ``'rcnn_bbox_preds'``, ``'rcnn_cls_scores'`` and
            ``'second_rpn_anchors'``.
        """
        prediction_dict = {}

        # first-stage features are shared with the RPN through feed_dict
        base_feat = self.feature_extractor.first_stage_feature(
            feed_dict['img'])
        feed_dict.update({'base_feat': base_feat})

        prediction_dict.update(self.rpn_model.forward(feed_dict))

        # shrink the proposal set before pooling to reduce memory use
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict)
        flat_rois = prediction_dict['rois_batch'].view(-1, 5)

        pooled_feat = self.rcnn_pooling(base_feat, flat_rois)
        pooled_feat = self.feature_extractor.second_stage_feature(pooled_feat)

        # the class head is applied to the feature map; averaging its output
        # over the two trailing axes gives the per-roi scores
        scores_map = self.rcnn_cls_pred(pooled_feat)
        rcnn_cls_scores = scores_map.mean(3).mean(2)
        rcnn_cls_probs = F.softmax(rcnn_cls_scores, dim=1)

        # softmax across dim 1 of the map; everything from index 1 on gates
        # the pooled features before they are averaged for the box head
        saliency_map = F.softmax(scores_map, dim=1)
        gated_feat = pooled_feat * saliency_map[:, 1:]
        rcnn_bbox_preds = self.rcnn_bbox_pred(gated_feat.mean(3).mean(2))

        prediction_dict.update({
            'rcnn_cls_probs': rcnn_cls_probs,
            'rcnn_bbox_preds': rcnn_bbox_preds,
            'rcnn_cls_scores': rcnn_cls_scores,
        })

        # used for track
        order = prediction_dict['proposals_order']
        prediction_dict['second_rpn_anchors'] = \
            prediction_dict['anchors'][0][order]

        return prediction_dict

    def init_weights(self):
        """Initialise the sub-models first, then both RCNN heads."""
        self.feature_extractor.init_weights()
        self.rpn_model.init_weights()

        # (head, std) pairs; zero mean for both
        for head, std in ((self.rcnn_cls_pred, 0.01),
                          (self.rcnn_bbox_pred, 0.001)):
            Filler.normal_init(head, 0, std, self.truncated)

    def init_modules(self):
        """Create backbone, RPN, RoI pooling, heads and loss modules.

        Raises:
            NotImplementedError: for the 'psalign' and 'deformable_psalign'
                pooling modes.
        """
        self.feature_extractor = ResNetFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = LossRPNModel(self.rpn_config)

        # any other pooling mode leaves self.rcnn_pooling undefined
        mode = self.pooling_mode
        if mode == 'align':
            self.rcnn_pooling = RoIAlignAvg(self.pooling_size,
                                            self.pooling_size, 1.0 / 16.0)
        elif mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)
        elif mode in ('psalign', 'deformable_psalign'):
            raise NotImplementedError('have not implemented yet!')

        # convolutional class head: yields a score map, not a score vector
        self.rcnn_cls_pred = nn.Conv2d(2048, self.n_classes, 3, 1, 1)

        num_box_outputs = 4 if self.class_agnostic else 4 * self.n_classes
        self.rcnn_bbox_pred = nn.Linear(2048, num_box_outputs)

        # both losses are left unreduced
        if self.use_focal_loss:
            self.rcnn_cls_loss = FocalLoss(2)
        else:
            self.rcnn_cls_loss = functools.partial(F.cross_entropy,
                                                   reduce=False)
        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

        # cluster loss for bbox and cls(feat)
        self.cluster_loss = ClusterLoss()

    def init_param(self, model_config):
        """Copy hyper-parameters from ``model_config`` and build helpers.

        Args:
            model_config: mapping with the keys read below.
        """
        self.classes = model_config['classes']
        self.n_classes = len(self.classes)

        # options stored under the same name as their config key
        for key in ('class_agnostic', 'pooling_size', 'pooling_mode',
                    'crop_resize_with_max_pool', 'truncated',
                    'use_focal_loss', 'subsample_twice', 'rcnn_batch_size',
                    'feature_extractor_config', 'rpn_config'):
            setattr(self, key, model_config[key])

        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])
        self.sampler = BalancedSampler(model_config['sampler_config'])

    def get_cluster_loss(self, num_gt, bbox_feat, cls_feat):
        """Accumulate the cluster loss once per ground-truth object.

        Args:
            num_gt: number of ground-truth objects, i.e. loop iterations.
            bbox_feat: currently unused.
            cls_feat: currently unused.

        Returns:
            Sum of ``num_gt`` evaluations of ``self.cluster_loss``; the
            integer 0 when ``num_gt`` is 0.

        Raises:
            AssertionError: if the matcher's ``match`` does not have a
                leading dimension of exactly 1.
        """
        cluster_loss = 0
        # only a batch of one image is supported
        match = self.target_assigner.matcher.match
        assert match.shape[0] == 1, 'only one num of batch is supported now'

        # NOTE(review): the criterion is invoked without arguments and
        # neither bbox_feat, cls_feat nor the match is used yet -- this looks
        # unfinished; confirm the intended inputs of ClusterLoss.
        for _ in range(num_gt):
            cluster_loss += self.cluster_loss()
        return cluster_loss

    def pre_subsample(self, prediction_dict, feed_dict):
        """Assign targets to every roi, then keep only a sampled subset.

        Reads ``rois_batch`` from ``prediction_dict`` and ``gt_boxes`` /
        ``gt_labels`` from ``feed_dict``. Writes the sampled
        ``rcnn_cls_weights`` and ``rcnn_reg_weights`` (each divided by its
        count of positive entries), ``rcnn_cls_targets``,
        ``rcnn_reg_targets``, ``rois_batch`` and ``match`` back into
        ``prediction_dict``. ``proposals_order`` is re-indexed only when the
        model is not in training mode.

        Raises:
            AssertionError: if no sampled cls or reg weight is positive.
        """
        proposals = prediction_dict['rois_batch']

        # assigner: the first of the five roi columns is not passed on
        cls_targets, reg_targets, cls_weights, reg_weights = \
            self.target_assigner.assign(proposals[:, :, 1:],
                                        feed_dict['gt_boxes'],
                                        feed_dict['gt_labels'])

        # subsampler: draw from all rois, the mask has shape (N,M)
        keep = self.sampler.subsample_batch(
            self.rcnn_batch_size,
            reg_weights > 0,
            indicator=cls_weights > 0,
            criterion=None)

        cls_weights = cls_weights[keep]
        reg_weights = reg_weights[keep]
        num_cls = (cls_weights > 0).sum(dim=-1)
        num_reg = (reg_weights > 0).sum(dim=-1)
        # check
        assert num_cls, 'bug happens'
        assert num_reg, 'bug happens'

        # weights are normalised by the number of positive entries
        prediction_dict['rcnn_cls_weights'] = cls_weights / num_cls.float()
        prediction_dict['rcnn_reg_weights'] = reg_weights / num_reg.float()
        prediction_dict['rcnn_cls_targets'] = cls_targets[keep]
        prediction_dict['rcnn_reg_targets'] = reg_targets[keep]

        # update rois_batch, restoring the leading batch dim
        sampled_rois = proposals[keep]
        prediction_dict['rois_batch'] = sampled_rois.view(
            proposals.shape[0], -1, 5)

        if not self.training:
            # used for track
            order = prediction_dict['proposals_order']
            prediction_dict['proposals_order'] = order[keep]

        # the matcher assignments are masked exactly like the targets
        assignments = self.target_assigner.matcher.assignments
        prediction_dict['match'] = assignments[keep]

    def loss(self, prediction_dict, feed_dict):
        """Collect the rpn losses and add the rcnn and cluster losses.

        Args:
            prediction_dict: must hold ``rcnn_cls_weights``,
                ``rcnn_reg_weights``, ``rcnn_cls_targets``,
                ``rcnn_reg_targets``, ``rcnn_cls_scores``,
                ``rcnn_bbox_preds`` and ``match``.
            feed_dict: must hold ``gt_boxes``; its dim 1 is used as the
                number of ground truths.

        Returns:
            dict with the entries returned by ``self.rpn_model.loss`` plus
            ``rcnn/cls_loss``, ``rcnn/bbox_loss`` and
            ``rpn/cluster_bbox_loss``.
        """
        # submodule loss
        loss_dict = dict(self.rpn_model.loss(prediction_dict, feed_dict))

        # targets and weights
        cls_weights = prediction_dict['rcnn_cls_weights']
        reg_weights = prediction_dict['rcnn_reg_weights']
        cls_targets = prediction_dict['rcnn_cls_targets']
        reg_targets = prediction_dict['rcnn_reg_targets']

        # classification loss: weighted in place, then reduced over the
        # last dim
        cls_scores = prediction_dict['rcnn_cls_scores']
        cls_loss = self.rcnn_cls_loss(cls_scores, cls_targets)
        cls_loss = cls_loss.mul_(cls_weights).sum(dim=-1)

        # bounding box regression L1 loss: reduce over the last dim, weight
        # in place, reduce again
        bbox_preds = prediction_dict['rcnn_bbox_preds']
        bbox_loss = self.rcnn_bbox_loss(bbox_preds, reg_targets).sum(dim=-1)
        bbox_loss = bbox_loss.mul_(reg_weights).sum(dim=-1)

        # cluster loss over the bbox predictions matched to each gt
        gt_count = feed_dict['gt_boxes'].shape[1]
        match = prediction_dict['match']
        cluster_bbox_loss = 0
        for gt_idx in range(gt_count):
            matched_preds = bbox_preds[match == gt_idx]
            cluster_bbox_loss += self.cluster_loss(matched_preds)

        loss_dict['rcnn/cls_loss'] = cls_loss
        loss_dict['rcnn/bbox_loss'] = bbox_loss
        loss_dict['rpn/cluster_bbox_loss'] = cluster_bbox_loss

        return loss_dict
예제 #23
0
class SemanticFasterRCNN(Model):
    def forward(self, feed_dict):
        """Run the two-stage pipeline and return all predictions.

        Stages, each wrapped in a profiler start/end pair: first-stage
        features of ``feed_dict['img']`` (also stored in ``feed_dict`` as
        ``base_feat``), the rpn, roi pooling, second-stage features. While
        training the rois are sub-sampled right after the rpn.

        Returns:
            dict holding the rpn outputs plus ``rcnn_cls_probs``,
            ``rcnn_bbox_preds``, ``rcnn_cls_scores`` and
            ``second_rpn_anchors``.
        """
        profiler = self.profiler
        prediction_dict = {}

        # base model
        profiler.start('base_model')
        image = feed_dict['img']
        base_feat = self.feature_extractor.first_stage_feature(image)
        feed_dict.update({'base_feat': base_feat})
        profiler.end('base_model')

        # rpn model
        profiler.start('rpn')
        rpn_outputs = self.rpn_model.forward(feed_dict)
        prediction_dict.update(rpn_outputs)
        profiler.end('rpn')

        # pre subsample to reduce memory consumption
        if self.training:
            self.pre_subsample(prediction_dict, feed_dict)
        rois_batch = prediction_dict['rois_batch']

        # note here base_feat (N,C,H,W), rois_batch (N,num_proposals,5)
        profiler.start('roipooling')
        roi_feat = self.rcnn_pooling(base_feat, rois_batch.view(-1, 5))
        profiler.end('roipooling')

        profiler.start('second_stage')
        roi_feat = self.feature_extractor.second_stage_feature(roi_feat)
        profiler.end('second_stage')

        # average over dim 3 and then dim 2 before the two heads
        roi_feat = roi_feat.mean(3).mean(2)

        bbox_preds = self.rcnn_bbox_pred(roi_feat)
        cls_scores = self.rcnn_cls_pred(roi_feat)

        prediction_dict['rcnn_cls_probs'] = F.softmax(cls_scores, dim=1)
        prediction_dict['rcnn_bbox_preds'] = bbox_preds
        prediction_dict['rcnn_cls_scores'] = cls_scores

        # used for track
        order = prediction_dict['proposals_order']
        anchors = prediction_dict['anchors']
        prediction_dict['second_rpn_anchors'] = anchors[order]

        return prediction_dict

    def generate_channel_attention(self, feat):
        """Average ``feat`` over dim 3 and then dim 2, keeping both dims."""
        pooled_dim3 = feat.mean(3, keepdim=True)
        pooled_dim2 = pooled_dim3.mean(2, keepdim=True)
        return pooled_dim2

    def generate_spatial_attention(self, feat):
        """Return the output of ``self.spatial_attention`` applied to ``feat``.

        ``init_modules`` creates ``spatial_attention`` only when
        ``use_self_attention`` is truthy.
        """
        attention_layer = self.spatial_attention
        return attention_layer(feat)

    def init_weights(self):
        """Initialise the submodules first, then the two rcnn heads."""
        # submodule init weights
        for submodule in (self.feature_extractor, self.rpn_model):
            submodule.init_weights()

        # (layer, std) per head, both drawn around mean 0; the bbox head
        # gets the smaller std
        heads = (
            (self.rcnn_cls_pred, 0.01),
            (self.rcnn_bbox_pred, 0.001),
        )
        for layer, std in heads:
            Filler.normal_init(layer, 0, std, self.truncated)

    def init_modules(self):
        """Build the submodules, the prediction heads and the loss modules.

        ``pooling_mode`` selects the roi pooling layer: ``'align'`` and
        ``'ps'`` are supported, ``'psalign'`` and ``'deformable_psalign'``
        raise ``NotImplementedError``, and any other value leaves
        ``rcnn_pooling`` unset.
        """
        self.feature_extractor = ResNetFeatureExtractor(
            self.feature_extractor_config)
        self.rpn_model = RPNModel(self.rpn_config)

        mode = self.pooling_mode
        if mode in ('psalign', 'deformable_psalign'):
            raise NotImplementedError('have not implemented yet!')
        if mode == 'align':
            size = self.pooling_size
            self.rcnn_pooling = RoIAlignAvg(size, size, 1.0 / 16.0)
        elif mode == 'ps':
            self.rcnn_pooling = PSRoIPool(7, 7, 1.0 / 16, 7, self.n_classes)

        # heads: one score per class, and 4 box values that are either
        # shared (class agnostic) or predicted per class
        self.rcnn_cls_pred = nn.Linear(2048, self.n_classes)
        bbox_out_dim = 4 if self.class_agnostic else 4 * self.n_classes
        self.rcnn_bbox_pred = nn.Linear(2048, bbox_out_dim)

        # loss module, both kept unreduced
        self.rcnn_cls_loss = (
            FocalLoss(self.n_classes) if self.use_focal_loss else
            functools.partial(F.cross_entropy, reduce=False))
        self.rcnn_bbox_loss = nn.modules.SmoothL1Loss(reduce=False)

        # attention
        if self.use_self_attention:
            self.spatial_attention = nn.Conv2d(2048, 1, 3, 1, 1)

    def init_param(self, model_config):
        """Read the model settings from ``model_config``.

        Every key is mandatory (``KeyError`` when absent) except
        ``use_self_attention``, which becomes ``None`` when missing.
        """
        self.classes = class_names = model_config['classes']
        # including bg
        self.n_classes = 1 + len(class_names)

        # options stored under the same name as their config key
        mandatory = ('class_agnostic', 'pooling_size', 'pooling_mode',
                     'crop_resize_with_max_pool', 'truncated',
                     'use_focal_loss', 'subsample_twice', 'rcnn_batch_size')
        for key in mandatory:
            setattr(self, key, model_config[key])
        self.use_self_attention = model_config.get('use_self_attention')

        # some submodule config
        self.feature_extractor_config = model_config['feature_extractor_config']
        self.rpn_config = model_config['rpn_config']

        # assigner, sampler and profiler
        self.target_assigner = TargetAssigner(
            model_config['target_assigner_config'])
        self.sampler = BalancedSampler(model_config['sampler_config'])
        self.profiler = Profiler()

    def pre_subsample(self, prediction_dict, feed_dict):
        """Label the rois and shrink them to the sampled mini-batch.

        Targets and weights come from ``self.target_assigner.assign`` and
        the selection mask from ``self.sampler.subsample_batch``. The
        sampled weights (divided by their count of positive entries),
        targets, the analyzer match (stored as ``fake_match``) and rois are
        written into ``prediction_dict``. ``proposals_order`` is masked
        only when the model is not in training mode.

        Raises:
            AssertionError: if no sampled cls or reg weight is positive.
        """
        all_rois = prediction_dict['rois_batch']
        gt_boxes, gt_labels = feed_dict['gt_boxes'], feed_dict['gt_labels']

        # assigner: the first roi column is left out
        assigned = self.target_assigner.assign(all_rois[:, :, 1:], gt_boxes,
                                               gt_labels)
        cls_targets, reg_targets, cls_weights, reg_weights = assigned

        # subsampler: sample from all rois, mask shape (N,M)
        positives = reg_weights > 0
        candidates = cls_weights > 0
        sampled = self.sampler.subsample_batch(self.rcnn_batch_size,
                                               positives,
                                               indicator=candidates,
                                               criterion=None)

        cls_weights, reg_weights = cls_weights[sampled], reg_weights[sampled]
        cls_count = (cls_weights > 0).sum(dim=-1)
        reg_count = (reg_weights > 0).sum(dim=-1)
        # check
        assert cls_count, 'bug happens'
        assert reg_count, 'bug happens'

        prediction_dict['rcnn_cls_weights'] = cls_weights / cls_count.float()
        prediction_dict['rcnn_reg_weights'] = reg_weights / reg_count.float()
        prediction_dict['rcnn_cls_targets'] = cls_targets[sampled]
        prediction_dict['rcnn_reg_targets'] = reg_targets[sampled]
        analyzer_match = self.target_assigner.analyzer.match
        prediction_dict['fake_match'] = analyzer_match[sampled]

        # update rois_batch, restoring the leading batch dim
        batch = all_rois.shape[0]
        prediction_dict['rois_batch'] = all_rois[sampled].view(batch, -1, 5)

        if not self.training:
            # used for track
            prediction_dict['proposals_order'] = prediction_dict[
                'proposals_order'][sampled]

    #  def umap_reg_targets():
    #  """
    #  expand rcnn_reg_targets(shape (N, 4) to shape(N, 4 * num_classes))
    #  """
    #  pass
    def squeeze_bbox_preds(self, rcnn_bbox_preds, rcnn_cls_targets):
        """
        squeeze rcnn_bbox_preds from shape (N, 4 * num_classes) to shape (N, 4)
        by keeping, for every row, the 4 values of that row's target class

        Args:
            rcnn_bbox_preds: shape(N, 4 * num_classes)
            rcnn_cls_targets: integer class index of every row, shape(N,)
                or shape(N, 1)
        Returns:
            tensor of shape(N, 4)
        """
        per_class_preds = rcnn_bbox_preds.view(-1, self.n_classes, 4)
        num_rois = per_class_preds.shape[0]

        # flatten first: a (N, 1) target would otherwise broadcast against
        # the (N,) offsets into a (N, N) index and select (N, N, 4) values
        flat_targets = rcnn_cls_targets.reshape(-1)

        # in the (N * num_classes, 4) view, the classes of roi i start at
        # row i * num_classes
        offset = torch.arange(0, num_rois) * per_class_preds.size(1)
        flat_index = flat_targets + offset.type_as(flat_targets)
        return per_class_preds.view(-1, 4)[flat_index]

    def loss(self, prediction_dict, feed_dict):
        """Collect the rpn losses, add the rcnn losses and feed the analyzer.

        Args:
            prediction_dict: must hold ``rcnn_cls_weights``,
                ``rcnn_reg_weights``, ``rcnn_cls_targets``,
                ``rcnn_reg_targets``, ``rcnn_cls_scores``,
                ``rcnn_bbox_preds``, ``rcnn_cls_probs`` and ``fake_match``.
            feed_dict: must hold ``gt_labels``; its element count is passed
                to the analyzer as the number of ground truths.

        Returns:
            dict with the entries returned by ``self.rpn_model.loss`` plus
            ``rcnn_cls_loss`` and ``rcnn_bbox_loss``.
        """
        # submodule loss
        loss_dict = dict(self.rpn_model.loss(prediction_dict, feed_dict))

        # targets and weights
        cls_weights = prediction_dict['rcnn_cls_weights']
        reg_weights = prediction_dict['rcnn_reg_weights']
        cls_targets = prediction_dict['rcnn_cls_targets']
        reg_targets = prediction_dict['rcnn_reg_targets']

        # classification loss: weighted in place, then reduced over the
        # last dim
        cls_scores = prediction_dict['rcnn_cls_scores']
        cls_loss = self.rcnn_cls_loss(cls_scores, cls_targets)
        cls_loss = cls_loss.mul_(cls_weights).sum(dim=-1)

        # bounding box regression L1 loss; per-class predictions are first
        # narrowed to the 4 values of each target class
        bbox_preds = prediction_dict['rcnn_bbox_preds']
        if not self.class_agnostic:
            bbox_preds = self.squeeze_bbox_preds(bbox_preds, cls_targets)
        bbox_loss = self.rcnn_bbox_loss(bbox_preds, reg_targets).sum(dim=-1)
        bbox_loss = bbox_loss.mul_(reg_weights).sum(dim=-1)

        loss_dict['rcnn_cls_loss'] = cls_loss
        loss_dict['rcnn_bbox_loss'] = bbox_loss

        # analysis ap: column 1 of the class probabilities is the score
        cls_probs = prediction_dict['rcnn_cls_probs']
        gt_count = feed_dict['gt_labels'].numel()
        analyzer = self.target_assigner.analyzer
        analyzer.analyze_ap(prediction_dict['fake_match'],
                            cls_probs[:, 1],
                            gt_count,
                            thresh=0.1)

        return loss_dict