コード例 #1
0
ファイル: FPN_test.py プロジェクト: xiangyanzhai/cascade_rcnn
    def __init__(self, config):
        """Build the FPN Faster R-CNN model described by ``config``.

        Args:
            config: Configuration object. Attributes read here: ``Mean``,
                ``anchor_scales`` and ``anchor_ratios`` (indexed per pyramid
                level 0..4, each entry a sized sequence), ``img_max``, the
                ``roi_*`` proposal settings and ``num_cls``.

        Raises:
            ValueError: If the five pyramid levels do not all define the same
                number of anchors per location (len(scales) * len(ratios)).
                Only one ``RPN_net`` is constructed, sized from level 0's
                count, so a level with a different count could not match it.
        """
        super(Faster_Rcnn, self).__init__()
        self.config = config
        # NOTE(review): plain tensor attribute rather than a registered buffer,
        # so Module.to()/cuda() will not move it -- confirm forward() handles
        # the device explicitly.
        self.Mean = torch.tensor(config.Mean, dtype=torch.float32)

        # Anchors per location for each of the 5 pyramid levels.
        self.num_anchor = [
            len(config.anchor_scales[i]) * len(config.anchor_ratios[i])
            for i in range(5)]
        # The single RPN head below is sized from level 0 (presumably shared
        # by all levels), so every level must agree; fail fast otherwise.
        if len(set(self.num_anchor)) != 1:
            raise ValueError(
                'every FPN level must use the same number of anchors '
                '(len(anchor_scales[i]) * len(anchor_ratios[i])), got %r'
                % (self.num_anchor,))

        # Precomputed anchor grids per level; the stride starts at 4 and
        # doubles with each level.
        self.anchors = []
        for i in range(5):
            stride = 4 * 2 ** i
            print(stride, self.config.anchor_scales[i], self.config.anchor_ratios[i])
            # NOTE(review): the first argument is a float (np.ceil output),
            # presumably a grid extent large enough to cover img_max -- confirm
            # against get_anchors.
            anchors = get_anchors(np.ceil(self.config.img_max / stride + 1), self.config.anchor_scales[i],
                                  self.config.anchor_ratios[i], stride=stride)
            print(anchors.shape)
            self.anchors.append(anchors)

        self.PC = ProposalCreator(nms_thresh=config.roi_nms_thresh,
                                  n_train_pre_nms=config.roi_train_pre_nms,
                                  n_train_post_nms=config.roi_train_post_nms,
                                  n_test_pre_nms=config.roi_test_pre_nms, n_test_post_nms=config.roi_test_post_nms,
                                  min_size=config.roi_min_size)

        # Backbone, feature pyramid (256 channels), RPN head and box head
        # (input 256 * 7 * 7, presumably 7x7 pooled RoI features).
        self.features = resnet101()
        self.fpn = FPN_net(256)
        self.rpn = RPN_net(256, self.num_anchor[0])
        self.fast = Fast_net(config.num_cls, 256 * 7 * 7, 1024)
        # NOTE(review): initialised to 0 and not used in this constructor;
        # presumably running statistics updated elsewhere -- confirm.
        self.a = 0
        self.b = 0
        self.c = 0
        self.d = 0
        self.fast_num = 0
        self.fast_num_P = 0
コード例 #2
0
    def __init__(self, config):
        """Build the FPN Mask R-CNN model (three box heads plus a mask head).

        Args:
            config: Configuration object. Attributes read here: ``Mean``,
                ``anchor_scales`` and ``anchor_ratios`` (indexed per pyramid
                level 0..4, each entry a sized sequence), ``img_max``, the
                ``rpn_*``, ``roi_*`` and ``fast_*`` sampling settings, and
                ``num_cls``.

        Raises:
            ValueError: If the five pyramid levels do not all define the same
                number of anchors per location (len(scales) * len(ratios)).
                Only one ``RPN_net`` is constructed, sized from level 0's
                count, so a level with a different count could not match it.
        """
        super(Mask_Rcnn, self).__init__()
        self.config = config
        # NOTE(review): plain tensor attribute rather than a registered buffer,
        # so Module.to()/cuda() will not move it -- confirm forward() handles
        # the device explicitly.
        self.Mean = torch.tensor(config.Mean, dtype=torch.float32)

        # Anchors per location for each of the 5 pyramid levels.
        self.num_anchor = [
            len(config.anchor_scales[i]) * len(config.anchor_ratios[i])
            for i in range(5)]
        # The single RPN head below is sized from level 0 (presumably shared
        # by all levels), so every level must agree; fail fast otherwise.
        if len(set(self.num_anchor)) != 1:
            raise ValueError(
                'every FPN level must use the same number of anchors '
                '(len(anchor_scales[i]) * len(anchor_ratios[i])), got %r'
                % (self.num_anchor,))

        # Precomputed anchor grids per level; the stride starts at 4 and
        # doubles with each level.
        self.anchors = []
        for i in range(5):
            stride = 4 * 2**i
            print(stride, self.config.anchor_scales[i],
                  self.config.anchor_ratios[i])
            # NOTE(review): the first argument is a float (np.ceil output),
            # presumably a grid extent large enough to cover img_max -- confirm
            # against get_anchors.
            anchors = get_anchors(np.ceil(self.config.img_max / stride + 1),
                                  self.config.anchor_scales[i],
                                  self.config.anchor_ratios[i],
                                  stride=stride)
            print(anchors.shape)
            self.anchors.append(anchors)

        self.ATC = AnchorTargetCreator(
            n_sample=config.rpn_n_sample,
            pos_iou_thresh=config.rpn_pos_iou_thresh,
            neg_iou_thresh=config.rpn_neg_iou_thresh,
            pos_ratio=config.rpn_pos_ratio)
        self.PC = ProposalCreator(nms_thresh=config.roi_nms_thresh,
                                  n_train_pre_nms=config.roi_train_pre_nms,
                                  n_train_post_nms=config.roi_train_post_nms,
                                  n_test_pre_nms=config.roi_test_pre_nms,
                                  n_test_post_nms=config.roi_test_post_nms,
                                  min_size=config.roi_min_size)

        # Proposal samplers for the three box heads. The first takes its IoU
        # thresholds from config; the second and third use hard-coded
        # 0.6 and 0.7 (presumably successive cascade stages).
        self.PTC_1 = ProposalTargetCreator_box(
            n_sample=config.fast_n_sample,
            pos_ratio=config.fast_pos_ratio,
            pos_iou_thresh=config.fast_pos_iou_thresh,
            neg_iou_thresh_hi=config.fast_neg_iou_thresh_hi,
            neg_iou_thresh_lo=config.fast_neg_iou_thresh_lo)
        self.PTC_2 = ProposalTargetCreator_box(
            n_sample=config.fast_n_sample,
            pos_ratio=config.fast_pos_ratio,
            pos_iou_thresh=0.6,
            neg_iou_thresh_hi=0.6,
            neg_iou_thresh_lo=config.fast_neg_iou_thresh_lo)
        self.PTC = ProposalTargetCreator(
            n_sample=config.fast_n_sample,
            pos_ratio=config.fast_pos_ratio,
            pos_iou_thresh=0.7,
            neg_iou_thresh_hi=0.7,
            neg_iou_thresh_lo=config.fast_neg_iou_thresh_lo)

        # Backbone, feature pyramid (256 channels), RPN head, three box heads
        # (input 256 * 7 * 7, presumably 7x7 pooled RoI features) and the
        # mask head.
        self.features = resnet101()
        self.fpn = FPN_net(256)
        self.rpn = RPN_net(256, self.num_anchor[0])
        self.fast = Fast_net(config.num_cls, 256 * 7 * 7, 1024)
        self.fast_2 = Fast_net(config.num_cls, 256 * 7 * 7, 1024)
        self.fast_3 = Fast_net(config.num_cls, 256 * 7 * 7, 1024)
        self.mask_net = Mask_net(256, config.num_cls)
        # NOTE(review): initialised to 0 and not used in this constructor;
        # presumably running statistics updated elsewhere -- confirm.
        self.a = 0
        self.b = 0
        self.c = 0
        self.d = 0
        self.fast_num = 0
        self.fast_num_P = 0

        # One 4-element list per box head (head 1 to head 3).
        # NOTE(review): presumably normalisation factors for the box-regression
        # targets, tightening at each head -- confirm against the loss code.
        self.loc_std1 = [1. / 10, 1. / 10, 1. / 5, 1. / 5]
        self.loc_std2 = [1. / 20, 1. / 20, 1. / 10, 1. / 10]
        self.loc_std3 = [1. / 30, 1. / 30, 1. / 15, 1. / 15]
        # One weight per box head (presumably applied to the per-head losses).
        self.loss_weights = [1.0, 0.5, 0.25]