# Example #1 (dataset separator marker; original score: 0)
    def __init__(self, classes, step2_model_path, step3_model_path, num_layers=101, fix_cnn_base=False, \
                class_agnostic=False, pretrained=False, base_model='resnet101'):
        """Build the step-4 model: a conv base and RPN restored from the
        step-3 checkpoint (both frozen), plus a detector head restored from
        the step-2 checkpoint (left trainable).

        Args:
            classes: sequence of class names; its length sets n_classes.
            step2_model_path: checkpoint holding the detector weights.
            step3_model_path: checkpoint holding the base + RPN weights.
            num_layers: ResNet depth for the shared base (50, 101 or 152).
            fix_cnn_base: stored flag (the base is frozen here regardless).
            class_agnostic: whether bbox regression is class-agnostic.
            pretrained: forwarded to the detector head.
            base_model: backbone name forwarded to the detector head.

        Raises:
            ValueError: if num_layers is not one of 50, 101, 152.
        """
        super(resnet_step4, self).__init__()
        self.step2_model_path = step2_model_path
        self.step3_model_path = step3_model_path
        self.dout_base_model = 1024  # channels out of layer3 for all supported depths
        self.pretrained = pretrained
        self.class_agnostic = class_agnostic
        self.num_layer = num_layers
        self.fix_cnn_base = fix_cnn_base
        self.classes = classes
        self.n_classes = len(classes)
        self.base_model = 'resnet' + str(self.num_layer)

        # loss accumulators
        self.RCNN_loss_cls = 0
        self.RCNN_loss_bbox = 0

        # define the base backbone
        if self.num_layer == 101:
            resnet = resnet101()
        elif self.num_layer == 50:
            resnet = resnet50()
        elif self.num_layer == 152:
            resnet = resnet152()
        else:
            # Previously an unsupported depth fell through and crashed later
            # with an opaque NameError; fail fast with a clear message.
            raise ValueError(
                "unsupported num_layers: %s (expected 50, 101 or 152)" % num_layers)

        print("Step4: Loading pretrained weights from %s and %s" %(step2_model_path,step3_model_path))
        state_dict_step2 = torch.load(self.step2_model_path)
        state_dict_step3 = torch.load(self.step3_model_path)

        # Build the shared conv base: everything up to layer3. layer4 lives in
        # the detector head; torchvision's final avgpool/fc are not used.
        self.RCNN_base = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
            resnet.maxpool, resnet.layer1, resnet.layer2, resnet.layer3)

        # Restore the base from the step-3 checkpoint and freeze it.
        self.RCNN_base.load_state_dict({k.replace('RCNN_base.',''):v for k,v in state_dict_step3['model'].items() if 'RCNN_base' in k})
        for key, value in dict(self.RCNN_base.named_parameters()).items():
            value.requires_grad = False

        # Restore the RPN from the step-3 checkpoint and freeze it.
        self.RCNN_rpn = _RPN(self.dout_base_model)
        self.RCNN_rpn.load_state_dict({k.replace('RCNN_rpn.',''):v for k,v in state_dict_step3['model'].items() if 'RCNN_rpn' in k})
        for key, value in dict(self.RCNN_rpn.named_parameters()).items():
            value.requires_grad = False

        # Detector head: weights come from the step-2 checkpoint; it is the
        # only part left trainable in this step.
        self.detector = _detector(self.classes, self.class_agnostic, pretrained, base_model=base_model)
        self.detector.load_state_dict({k.replace('detector.',''):v for k,v in state_dict_step2['model'].items() if 'detector' in k})
    def _init_modules(self):
        """Build the ResNet conv base (conv1 .. layer3) and freeze blocks.

        Loads weights from ``self.model_path`` when ``self.pretrained`` is
        true, then freezes the stem (conv1/bn1), the first
        ``cfg.RESNET.FIXED_BLOCKS`` residual stages, and every BatchNorm
        layer's parameters.

        NOTE(review): the visible ``__init__`` above sets the step2/step3
        checkpoint paths but never ``self.model_path`` — confirm callers
        define it before invoking this with ``pretrained=True``.

        Raises:
            ValueError: if ``self.num_layer`` is not 50, 101 or 152.
        """
        if self.num_layer == 101:
            resnet = resnet101()
        elif self.num_layer == 50:
            resnet = resnet50()
        elif self.num_layer == 152:
            resnet = resnet152()
        else:
            # Fail fast: falling through used to raise an opaque NameError.
            raise ValueError(
                "unsupported num_layer: %s (expected 50, 101 or 152)" % self.num_layer)

        if self.pretrained:
            print("Loading pretrained weights from %s" % (self.model_path))
            state_dict = torch.load(self.model_path)
            # Copy only keys the torchvision model actually declares.
            resnet.load_state_dict({
                k: v
                for k, v in state_dict.items() if k in resnet.state_dict()
            })

        # Build the resnet trunk; layer4 and the classifier are not used here.
        self.RCNN_base = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                       resnet.maxpool, resnet.layer1,
                                       resnet.layer2, resnet.layer3)

        # Always freeze the stem (conv1 at index 0, bn1 at index 1).
        for p in self.RCNN_base[0].parameters():
            p.requires_grad = False
        for p in self.RCNN_base[1].parameters():
            p.requires_grad = False

        # Freeze the first FIXED_BLOCKS residual stages; indices 4..6 in the
        # Sequential above are layer1..layer3.
        assert (0 <= cfg.RESNET.FIXED_BLOCKS < 4)
        if cfg.RESNET.FIXED_BLOCKS >= 3:
            for p in self.RCNN_base[6].parameters():
                p.requires_grad = False
        if cfg.RESNET.FIXED_BLOCKS >= 2:
            for p in self.RCNN_base[5].parameters():
                p.requires_grad = False
        if cfg.RESNET.FIXED_BLOCKS >= 1:
            for p in self.RCNN_base[4].parameters():
                p.requires_grad = False

        def set_bn_fix(m):
            # Freeze all BatchNorm affine parameters; presumably because small
            # detection batches make BN statistics unreliable to train.
            classname = m.__class__.__name__
            if classname.find('BatchNorm') != -1:
                for p in m.parameters():
                    p.requires_grad = False

        self.RCNN_base.apply(set_bn_fix)
# Example #3 (dataset separator marker; original score: 0)
    def __init__(self, classes, num_layers=101, fix_cnn_base=False, \
                    class_agnostic=False, pretrained=False, base_model='resnet101'):
        """Step-1 model: ResNet conv base + RPN.

        Optionally loads caffe-style ImageNet weights into the base, optionally
        freezes the base per ``cfg.RESNET.FIXED_BLOCKS``, then builds the RPN
        and normal-initializes its conv/cls/bbox layers.

        Args:
            classes: sequence of class names; its length sets n_classes.
            num_layers: ResNet depth (50, 101 or 152).
            fix_cnn_base: if True, freeze stem/blocks/BatchNorm in the base.
            class_agnostic: stored; whether bbox regression is class-agnostic.
            pretrained: if True, load weights from the matching caffe .pth.
            base_model: stored by callers elsewhere; unused in this body.

        Raises:
            ValueError: if num_layers is not one of 50, 101, 152.
        """
        super(resnet_step1, self).__init__()
        if num_layers == 101:
            self.model_path = 'data/pretrained_model/resnet101_caffe.pth'
        elif num_layers == 50:
            self.model_path = 'data/pretrained_model/resnet50_caffe.pth'
        elif num_layers == 152:
            self.model_path = 'data/pretrained_model/resnet152_caffe.pth'
        else:
            # Previously an unsupported depth left self.model_path (and the
            # local `resnet` below) undefined; fail fast instead.
            raise ValueError(
                "unsupported num_layers: %s (expected 50, 101 or 152)" % num_layers)
        self.dout_base_model = 1024  # channels out of layer3 for all supported depths
        self.pretrained = pretrained
        self.class_agnostic = class_agnostic
        self.num_layer = num_layers
        self.fix_cnn_base = fix_cnn_base
        self.classes = classes
        self.n_classes = len(classes)

        # loss accumulators
        self.RCNN_loss_cls = 0
        self.RCNN_loss_bbox = 0

        # define the base backbone
        if self.num_layer == 101:
            resnet = resnet101()
        elif self.num_layer == 50:
            resnet = resnet50()
        else:  # validated above, so this must be 152
            resnet = resnet152()
        if self.pretrained:
            print("Step1: Loading pretrained weights from %s" %(self.model_path))
            state_dict = torch.load(self.model_path)
            # Copy only keys the torchvision model actually declares.
            resnet.load_state_dict({k:v for k,v in state_dict.items() if k in resnet.state_dict()})
        # Build the conv base: everything up to layer3; layer4/avgpool/fc of
        # torchvision's resnet are not used here.
        self.RCNN_base = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, \
            resnet.maxpool, resnet.layer1, resnet.layer2, resnet.layer3)

        # Optionally fix (freeze) parts of the base.
        if fix_cnn_base:
            # Stem: conv1 at index 0, bn1 at index 1.
            for p in self.RCNN_base[0].parameters(): p.requires_grad=False
            for p in self.RCNN_base[1].parameters(): p.requires_grad=False

            # Freeze the first FIXED_BLOCKS residual stages; indices 4..6 in
            # the Sequential above are layer1..layer3.
            assert (0 <= cfg.RESNET.FIXED_BLOCKS < 4)
            if cfg.RESNET.FIXED_BLOCKS >= 3:
                for p in self.RCNN_base[6].parameters(): p.requires_grad=False
            if cfg.RESNET.FIXED_BLOCKS >= 2:
                for p in self.RCNN_base[5].parameters(): p.requires_grad=False
            if cfg.RESNET.FIXED_BLOCKS >= 1:
                for p in self.RCNN_base[4].parameters(): p.requires_grad=False

            def set_bn_fix(m):
                # Freeze all BatchNorm affine parameters.
                classname = m.__class__.__name__
                if classname.find('BatchNorm') != -1:
                    for p in m.parameters(): p.requires_grad=False
            self.RCNN_base.apply(set_bn_fix)

        # define the RPN
        self.RCNN_rpn = _RPN(self.dout_base_model)

        def normal_init(m, mean, stddev, truncated=False):
            """Weight initializer: truncated normal or random normal.

            Zeroes the bias in both cases. (The original zeroed it only on
            the non-truncated path, leaving the bias at its default init
            when cfg.TRAIN.TRUNCATED was set — fixed here.)
            """
            if truncated:
                # not a perfect approximation of a truncated normal
                m.weight.data.normal_().fmod_(2).mul_(stddev).add_(mean)
            else:
                m.weight.data.normal_(mean, stddev)
            m.bias.data.zero_()

        normal_init(self.RCNN_rpn.RPN_Conv, 0, 0.01, cfg.TRAIN.TRUNCATED)
        normal_init(self.RCNN_rpn.RPN_cls_score, 0, 0.01, cfg.TRAIN.TRUNCATED)
        normal_init(self.RCNN_rpn.RPN_bbox_pred, 0, 0.01, cfg.TRAIN.TRUNCATED)
    def __init__(self,
                 classes,
                 class_agnostic,
                 pretrained=False,
                 base_model='vgg16'):
        """Detector head: RoI feature extraction + classification/regression.

        Builds the proposal-target layer, the RoI pooling/align/crop ops, and
        a backbone-specific top (VGG16 classifier FCs or ResNet layer4)
        feeding the cls-score and bbox-pred linear layers.

        Args:
            classes: sequence of class names; its length sets n_classes.
            class_agnostic: if True, bbox_pred outputs 4 values instead of
                4 * n_classes.
            pretrained: if True, load caffe-style weights for the chosen top.
            base_model: 'vgg16', 'resnet50', 'resnet101' or 'resnet152'.

        Raises:
            ValueError: if base_model is not one of the supported backbones.
        """
        super(_detector, self).__init__()
        self.classes = classes
        self.n_classes = len(classes)
        self.class_agnostic = class_agnostic
        self.pretrained = pretrained

        # Turns RPN proposals into training targets for the head.
        self.RCNN_proposal_target = _ProposalTargetLayer(self.n_classes)

        # RoI pool / align / crop ops. The 1/16 spatial scale maps image
        # coordinates onto the stride-16 feature map.
        self.RCNN_roi_pool = _RoIPooling(cfg.POOLING_SIZE, cfg.POOLING_SIZE,
                                         1.0 / 16.0)
        self.RCNN_roi_align = RoIAlignAvg(cfg.POOLING_SIZE, cfg.POOLING_SIZE,
                                          1.0 / 16.0)
        self.grid_size = cfg.POOLING_SIZE * 2 if cfg.CROP_RESIZE_WITH_MAX_POOL else cfg.POOLING_SIZE
        self.RCNN_roi_crop = _RoICrop()
        self.base_model = base_model

        # Backbone-specific top.
        if base_model == 'vgg16':
            self.model_path = 'data/pretrained_model/vgg16_caffe.pth'
            vgg = models.vgg16()
            if self.pretrained:
                state_dict = torch.load(self.model_path)
                # Copy only keys the torchvision model actually declares.
                vgg.load_state_dict({
                    k: v
                    for k, v in state_dict.items() if k in vgg.state_dict()
                })
            # Keep VGG's fully-connected classifier, dropping its final
            # (1000-way) layer.
            self.RCNN_top = nn.Sequential(
                *list(vgg.classifier._modules.values())[:-1])
            self.RCNN_cls_score = nn.Linear(4096, self.n_classes)

            if self.class_agnostic:
                self.RCNN_bbox_pred = nn.Linear(4096, 4)
            else:
                self.RCNN_bbox_pred = nn.Linear(4096, 4 * self.n_classes)
        elif base_model == 'resnet101' or base_model == 'resnet50' or base_model == 'resnet152':
            if base_model == 'resnet101':
                self.model_path = 'data/pretrained_model/resnet101_caffe.pth'
                resnet = resnet101()
            elif base_model == 'resnet50':
                self.model_path = 'data/pretrained_model/resnet50_caffe.pth'
                resnet = resnet50()
            elif base_model == 'resnet152':
                self.model_path = 'data/pretrained_model/resnet152_caffe.pth'
                resnet = resnet152()
            if self.pretrained:
                print("Detector: Loading pretrained weights from %s" %
                      (self.model_path))
                state_dict = torch.load(self.model_path)
                # Copy only keys the torchvision model actually declares.
                resnet.load_state_dict({
                    k: v
                    for k, v in state_dict.items() if k in resnet.state_dict()
                })
            # ResNet's layer4 (2048 channels out) serves as the head.
            self.RCNN_top = nn.Sequential(resnet.layer4)
            self.RCNN_cls_score = nn.Linear(2048, self.n_classes)
            if self.class_agnostic:
                self.RCNN_bbox_pred = nn.Linear(2048, 4)
            else:
                self.RCNN_bbox_pred = nn.Linear(2048, 4 * self.n_classes)

            def set_bn_fix(m):
                # Freeze all BatchNorm affine parameters in the top.
                classname = m.__class__.__name__
                if classname.find('BatchNorm') != -1:
                    for p in m.parameters():
                        p.requires_grad = False

            self.RCNN_top.apply(set_bn_fix)
        else:
            # Previously this printed a message and called exit(), killing the
            # whole process from library code; raise so callers can handle it.
            raise ValueError(
                "no support for other CNN model: %s" % base_model)