Example #1
    # Assumes: import torch; import torch.nn as nn;
    # from torchvision.models import resnet50
    def _init_modules(self):
        resnet = resnet50()
        if self.pretrained:
            print("Loading pretrained weights from %s" % (self.model_path))
            state_dict = torch.load(self.model_path)
            # Keep only the keys that exist in this model (drops e.g. the fc head).
            resnet.load_state_dict({k: v for k, v in state_dict.items() if k in resnet.state_dict()})

        # Build resnet. (base -> top -> head)
        self.RCNN_base = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
            resnet.maxpool,resnet.layer1,resnet.layer2,resnet.layer3)
        self.RCNN_top = nn.Sequential(resnet.layer4)  # 1024 -> 2048
        # Build the R-CNN box head: 4 box deltas (class-agnostic).
        self.RCNN_bbox_pred = nn.Linear(2048, 4)

        # Fix blocks: freeze the stem (conv1 and bn1).
        for p in self.RCNN_base[0].parameters(): p.requires_grad = False
        for p in self.RCNN_base[1].parameters(): p.requires_grad = False

        # RCNN_base indices: 4 = layer1, 5 = layer2, 6 = layer3.
        assert (0 <= cfg.RESNET.FIXED_BLOCKS < 4)
        if cfg.RESNET.FIXED_BLOCKS >= 3:
            for p in self.RCNN_base[6].parameters(): p.requires_grad = False
        if cfg.RESNET.FIXED_BLOCKS >= 2:
            for p in self.RCNN_base[5].parameters(): p.requires_grad = False
        if cfg.RESNET.FIXED_BLOCKS >= 1:
            for p in self.RCNN_base[4].parameters(): p.requires_grad = False

        def set_bn_fix(m):
            # Freeze all BatchNorm affine parameters (weight and bias).
            classname = m.__class__.__name__
            if classname.find('BatchNorm') != -1:
                for p in m.parameters(): p.requires_grad = False

        self.RCNN_base.apply(set_bn_fix)
        self.RCNN_top.apply(set_bn_fix)
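
Setting requires_grad = False alone does not stop BatchNorm from updating its running mean and variance, so repositories that use this pattern (e.g. jwyang/faster-rcnn.pytorch) typically pair _init_modules with a train() override that puts the frozen layers back into eval mode. A minimal sketch, assuming the RCNN_base layout of Example #1 and cfg.RESNET.FIXED_BLOCKS == 1:

    def train(self, mode=True):
        # Run the standard nn.Module mode switch first.
        nn.Module.train(self, mode)
        if mode:
            # Put the whole base in eval mode, then re-enable training
            # only for the blocks that are not fixed (layer2, layer3).
            self.RCNN_base.eval()
            self.RCNN_base[5].train()
            self.RCNN_base[6].train()

            def set_bn_eval(m):
                # Frozen BatchNorm must stay in eval mode so its
                # running statistics are not updated during training.
                if m.__class__.__name__.find('BatchNorm') != -1:
                    m.eval()

            self.RCNN_base.apply(set_bn_eval)
            self.RCNN_top.apply(set_bn_eval)

Without such an override, a later call to model.train() would silently flip every BatchNorm layer back into training mode.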
Example #2
    # Assumes: import torch; import torch.nn as nn;
    # from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152;
    # plus the repo-local modules netD_pixel, netD, netD_mid, netD_da and RandomLayer.
    def _init_modules(self):

        # Default to ResNet-101; override for other depths.
        resnet = resnet101()

        if self.layers == 50:
            resnet = resnet50()
        elif self.layers == 152:
            resnet = resnet152()
        if self.pretrained and self.layers in [50, 101, 152]:
            print("Loading pretrained weights from %s" % (self.model_path))
            state_dict = torch.load(self.model_path)
            resnet.load_state_dict(
                {
                    k: v
                    for k, v in state_dict.items() if k in resnet.state_dict()
                },
                strict=False)
            in_chan = 256
        elif self.layers in [50, 101, 152]:
            print("Loading pretrained weights from pytorch")
            # Load torchvision's pretrained weights for the requested depth.
            constructors = {50: resnet50, 101: resnet101, 152: resnet152}
            resnet = constructors[self.layers](pretrained=True)
            in_chan = 256
        else:
            if self.layers == 18:
                resnet = resnet18(pretrained=True)
            elif self.layers == 34:
                resnet = resnet34(pretrained=True)
            in_chan = 64

        # Build the ResNet trunk, split into stages so the domain
        # discriminators below can tap features at different depths.
        self.RCNN_base1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                        resnet.maxpool, resnet.layer1)
        self.RCNN_base2 = nn.Sequential(resnet.layer2)
        self.RCNN_base3 = nn.Sequential(resnet.layer3)

        if self.layers in [50, 101, 152]:
            # Bottleneck ResNets: default discriminator widths, 2048-d top features.
            self.netD_pixel = netD_pixel(context=self.lc)
            self.netD = netD(context=self.gc)
            self.netD_mid = netD_mid(context=self.gc)
            feat_d = 2048
        else:
            # BasicBlock ResNets (18/34): narrower feature maps, 512-d top features.
            self.netD_pixel = netD_pixel(context=self.lc, in_chan=in_chan)
            self.netD = netD(context=self.gc, in_chan=256)
            self.netD_mid = netD_mid(context=self.gc, in_chan=in_chan * 2)
            feat_d = 512

        self.RCNN_top = nn.Sequential(resnet.layer4)

        feat_d2 = 384
        feat_d3 = 1024

        # RandomLayer fuses the feat_d-dim pooled features with a second
        # feat_d2-dim feature into a feat_d3-dim joint representation,
        # which the domain classifier netD_da consumes.
        self.RandomLayer = RandomLayer([feat_d, feat_d2], feat_d3)
        self.RandomLayer.cuda()

        self.netD_da = netD_da(feat_d3)

        self.stu_feature_adap = nn.Sequential(
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1), nn.ReLU())

        self.RCNN_cls_score = nn.Linear(feat_d + feat_d2, self.n_classes)
        if self.class_agnostic:
            # A single set of 4 box deltas shared across classes.
            self.RCNN_bbox_pred = nn.Linear(feat_d + feat_d2, 4)
        else:
            # 4 box deltas per class.
            self.RCNN_bbox_pred = nn.Linear(feat_d + feat_d2,
                                            4 * self.n_classes)

        # Fix blocks: freeze the stem (conv1 and bn1).
        for p in self.RCNN_base1[0].parameters():
            p.requires_grad = False
        for p in self.RCNN_base1[1].parameters():
            p.requires_grad = False

        # assert (0 <= cfg.RESNET.FIXED_BLOCKS < 4)
        # if cfg.RESNET.FIXED_BLOCKS >= 3:
        #     for p in self.RCNN_base1[6].parameters(): p.requires_grad = False
        # if cfg.RESNET.FIXED_BLOCKS >= 2:
        #     for p in self.RCNN_base1[5].parameters(): p.requires_grad = False
        # if cfg.RESNET.FIXED_BLOCKS >= 1:
        #     for p in self.RCNN_base1[4].parameters(): p.requires_grad = False

        def set_bn_fix(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm') != -1:
                for p in m.parameters():
                    p.requires_grad = False

        self.RCNN_base1.apply(set_bn_fix)
        self.RCNN_base2.apply(set_bn_fix)
        self.RCNN_base3.apply(set_bn_fix)
        self.RCNN_top.apply(set_bn_fix)
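
netD_pixel, netD, netD_mid and netD_da are repo-local domain discriminators whose definitions are not shown in this example. As a rough illustration of the kind of module a name like netD_pixel usually denotes in domain-adaptive detection code, here is a hypothetical fully convolutional pixel-level discriminator with a gradient reversal layer; the class names, channel widths and the lambd scaling are assumptions made for this sketch, not the repo's actual code.

import torch
import torch.nn as nn


class GradReverse(torch.autograd.Function):
    # Identity on the forward pass; negates (and scales) the gradient on
    # the backward pass, so the backbone learns domain-invariant features.
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.lambd, None


class PixelDiscriminator(nn.Module):
    # Hypothetical stand-in for netD_pixel: 1x1 convolutions only, so it
    # emits a domain score for every spatial location of the feature map.
    def __init__(self, in_chan=256, lambd=1.0):
        super().__init__()
        self.lambd = lambd
        self.conv1 = nn.Conv2d(in_chan, 256, kernel_size=1)
        self.conv2 = nn.Conv2d(256, 128, kernel_size=1)
        self.conv3 = nn.Conv2d(128, 1, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = GradReverse.apply(x, self.lambd)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return torch.sigmoid(self.conv3(x))  # per-location domain probability

The gradient reversal trick lets a single backward pass train the discriminator to tell domains apart while pushing the feature extractor in the opposite direction.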