Exemplo n.º 1
0
    def _init_modules(self):
        """Optionally load ImageNet weights, then optionally freeze the backbone.

        Controlled by ``cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS`` and
        ``cfg.TRAIN.FREEZE_CONV_BODY``; both act on ``self`` in place.
        """
        load_imagenet = cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS
        if load_imagenet:
            resnet_utils.load_pretrained_imagenet_weights(self)

        if not cfg.TRAIN.FREEZE_CONV_BODY:
            return
        # Exclude every backbone parameter from gradient updates.
        for param in self.Conv_Body.parameters():
            param.requires_grad = False
    def _init_modules(self):
        """Load ImageNet weights, verify shared res5 copies, apply freezing.

        With ``cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS`` set, the ImageNet
        weights are loaded and each enabled head (mask, keypoint) whose
        ``SHARE_RES5`` attribute is truthy is asserted to match the box head's
        ``res5`` weights. Afterwards, per ``cfg.TRAIN`` flags, the conv body
        (``FREEZE_CONV_BODY``), the RPN (``FREEZE_RPN``) and the box head plus
        box outputs (``FREEZE_FPN``) have ``requires_grad`` cleared.
        """
        if cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS:
            resnet_utils.load_pretrained_imagenet_weights(self)
            # Heads that reuse the box head's res5 must hold identical weights.
            for flag_name, head_name in (('MASK_ON', 'Mask_Head'),
                                         ('KEYPOINTS_ON', 'Keypoint_Head')):
                if not getattr(cfg.MODEL, flag_name):
                    continue
                head = getattr(self, head_name)
                if getattr(head, 'SHARE_RES5', False):
                    assert compare_state_dict(head.res5.state_dict(),
                                              self.Box_Head.res5.state_dict())

        to_freeze = []
        if cfg.TRAIN.FREEZE_CONV_BODY:
            to_freeze.append(self.Conv_Body)
        if cfg.TRAIN.FREEZE_RPN:
            to_freeze.append(self.RPN)
        if cfg.TRAIN.FREEZE_FPN:
            to_freeze += [self.Box_Head, self.Box_Outs]
        for module in to_freeze:
            for param in module.parameters():
                param.requires_grad = False
    def _init_modules(self):
        """Load the configured pretrained weights and freeze the conv body.

        Weight sources are applied in this order, each only when its config
        path is non-empty, so a later source overwrites parameters written by
        an earlier one:

        1. ``cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS`` through
           ``resnet_utils.load_pretrained_imagenet_weights``.
        2. ``cfg.RESNETS.COCO_PRETRAINED_WEIGHTS`` through
           ``load_detectron_weight`` with ``('cls_score', 'bbox_pred')``
           (presumably the layers to skip -- confirm in that helper).
        3. ``cfg.VGG16.COCO_PRETRAINED_WEIGHTS``: a ``torch.load``-able
           checkpoint whose ``'model'`` entry is loaded after the ``Box_Outs``
           classifier and box-regression tensors are removed.
        4. ``cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS``: same checkpoint layout,
           loaded in full.

        Between steps 1 and 2, every enabled head that declares
        ``SHARE_RES5`` is asserted to hold the same ``res5`` weights as the
        box head. Finally ``Conv_Body`` is frozen when
        ``cfg.TRAIN.FREEZE_CONV_BODY`` is set.
        """
        def _state_dicts_equal(sd_a, sd_b):
            # ``state_dict()`` returns freshly detached tensors, so a plain
            # ``sd_a == sd_b`` falls through to elementwise tensor ``==`` and
            # then raises "Boolean value of Tensor with more than one value is
            # ambiguous". Compare the key sets, then each tensor exactly.
            return (sd_a.keys() == sd_b.keys()
                    and all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a))

        if cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS != '':
            logger.info("Loading pretrained weights from %s", cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS)
            resnet_utils.load_pretrained_imagenet_weights(self)
        # Check that shared weights are equal.
        if cfg.MODEL.MASK_ON and getattr(self.Mask_Head, 'SHARE_RES5', False):
            assert _state_dicts_equal(self.Mask_Head.res5.state_dict(),
                                      self.Box_Head.res5.state_dict())
        if cfg.MODEL.KEYPOINTS_ON and getattr(self.Keypoint_Head, 'SHARE_RES5', False):
            assert _state_dicts_equal(self.Keypoint_Head.res5.state_dict(),
                                      self.Box_Head.res5.state_dict())

        # load detectron pretrained weights for resnet
        if cfg.RESNETS.COCO_PRETRAINED_WEIGHTS != '':
            logger.info("loading detectron pretrained weights from %s", cfg.RESNETS.COCO_PRETRAINED_WEIGHTS)
            load_detectron_weight(self, cfg.RESNETS.COCO_PRETRAINED_WEIGHTS, ('cls_score', 'bbox_pred'))

        if cfg.VGG16.COCO_PRETRAINED_WEIGHTS != '':
            logger.info("loading pretrained weights from %s", cfg.VGG16.COCO_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.VGG16.COCO_PRETRAINED_WEIGHTS, map_location=lambda storage, loc: storage)
            # Not reusing the final classification / box-regression layers;
            # a missing key raises KeyError (unexpected checkpoint layout).
            for key in ('Box_Outs.cls_score.weight', 'Box_Outs.cls_score.bias',
                        'Box_Outs.bbox_pred.weight', 'Box_Outs.bbox_pred.bias'):
                del checkpoint['model'][key]
            net_utils.load_ckpt(self, checkpoint['model'])

        if cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS != '':
            logger.info("loading trained and to be finetuned weights from %s", cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS)
            checkpoint = torch.load(cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS, map_location=lambda storage, loc: storage)
            net_utils.load_ckpt(self, checkpoint['model'])

        if cfg.TRAIN.FREEZE_CONV_BODY:
            for p in self.Conv_Body.parameters():
                p.requires_grad = False
Exemplo n.º 4
0
    def _init_modules(self):
        """Load pretrained, detector and predicate-branch weights; set freezing.

        Steps, each gated on its config path being non-empty:

        1. ImageNet ResNet weights (``cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS``);
           the VGG16 ImageNet model is initialised in VGG16.py instead.
        2. Detector weights via ``self.load_detector_weights`` for VRD then
           VG, ResNet before VGG16.
        3. ``Conv_Body`` is frozen when ``cfg.TRAIN.FREEZE_CONV_BODY`` is set.
        4. Predicate-branch checkpoint (``*_PRD_PRETRAINED_WEIGHTS``): its
           ``'model'`` entry, minus the ``Box_Outs`` classifier/regressor, is
           loaded into ``self.Prd_RCNN``; its conv body / box head are then
           frozen per ``cfg.TRAIN.FREEZE_PRD_*``.

        If several predicate paths are configured, the last one in the order
        ResNet-VRD, VGG16-VRD, ResNet-VG, VGG16-VG wins (same precedence as
        before). Only that file is read, and a warning is logged.
        """
        def _freeze(module):
            # Exclude every parameter of ``module`` from gradient updates.
            for p in module.parameters():
                p.requires_grad = False

        def _load_last_configured(paths, log_msg):
            # Load the last non-empty path in ``paths`` (None if all empty).
            # Earlier paths would be overridden anyway, so they are not
            # deserialised; several non-empty paths is flagged as ambiguous.
            configured = [path for path in paths if path != '']
            if not configured:
                return None
            path = configured[-1]
            if len(configured) > 1:
                logger.warning("multiple weight paths configured %s; using the last one: %s",
                               configured, path)
            logger.info(log_msg, path)
            return torch.load(path, map_location=lambda storage, loc: storage)

        # VGG16 imagenet pretrained model is initialized in VGG16.py
        if cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS != '':
            logger.info("Loading pretrained weights from %s",
                        cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS)
            resnet_utils.load_pretrained_imagenet_weights(self)

        for detector_path in (cfg.RESNETS.VRD_PRETRAINED_WEIGHTS,
                              cfg.VGG16.VRD_PRETRAINED_WEIGHTS,
                              cfg.RESNETS.VG_PRETRAINED_WEIGHTS,
                              cfg.VGG16.VG_PRETRAINED_WEIGHTS):
            if detector_path != '':
                self.load_detector_weights(detector_path)

        if cfg.TRAIN.FREEZE_CONV_BODY:
            _freeze(self.Conv_Body)

        checkpoint = _load_last_configured(
            (cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS,
             cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS,
             cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS,
             cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS),
            "loading prd pretrained weights from %s")
        if checkpoint is not None:
            # not using the last softmax layers
            for key in ('Box_Outs.cls_score.weight', 'Box_Outs.cls_score.bias',
                        'Box_Outs.bbox_pred.weight', 'Box_Outs.bbox_pred.bias'):
                del checkpoint['model'][key]
            net_utils.load_ckpt(self.Prd_RCNN, checkpoint['model'])
            if cfg.TRAIN.FREEZE_PRD_CONV_BODY:
                _freeze(self.Prd_RCNN.Conv_Body)
            if cfg.TRAIN.FREEZE_PRD_BOX_HEAD:
                _freeze(self.Prd_RCNN.Box_Head)
Exemplo n.º 5
0
    def _init_modules(self):
        """Optionally load ImageNet weights and freeze the conv body, with progress prints.

        Reads ``cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS`` and
        ``cfg.TRAIN.FREEZE_CONV_BODY``; prints a message around each action.
        """
        if cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS:
            print('loading weights for ResNet')
            resnet_utils.load_pretrained_imagenet_weights(self)
            print('loading weights is done')

        if not cfg.TRAIN.FREEZE_CONV_BODY:
            return
        print('freeze train conv_body')
        for weight in self.Conv_Body.parameters():
            weight.requires_grad = False
    def _init_modules(self):
        """Load ImageNet weights, verify shared res5 copies, optionally freeze.

        When ``cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS`` is set, weights are
        loaded and each enabled head (mask, keypoint) with a truthy
        ``SHARE_RES5`` is asserted to match the box head's ``res5`` weights.
        ``Conv_Body`` is frozen when ``cfg.TRAIN.FREEZE_CONV_BODY`` is set.
        """
        if cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS:
            resnet_utils.load_pretrained_imagenet_weights(self)
            # Heads that reuse the box head's res5 must hold identical weights.
            for flag_name, head_name in (('MASK_ON', 'Mask_Head'),
                                         ('KEYPOINTS_ON', 'Keypoint_Head')):
                if not getattr(cfg.MODEL, flag_name):
                    continue
                head = getattr(self, head_name)
                if getattr(head, 'SHARE_RES5', False):
                    assert compare_state_dict(head.res5.state_dict(),
                                              self.Box_Head.res5.state_dict())

        if cfg.TRAIN.FREEZE_CONV_BODY:
            for weight in self.Conv_Body.parameters():
                weight.requires_grad = False
Exemplo n.º 7
0
    def _init_modules(self):
        # print(self.state_dict().keys())
        if cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS:
            resnet_utils.load_pretrained_imagenet_weights(self)
            # Check if shared weights are equaled
            if cfg.MODEL.MASK_ON and getattr(self.Amodal_Mask_Head, 'SHARE_RES5', False):
                assert self.Amodal_Mask_Head.res5.state_dict() == self.Box_Head.res5.state_dict()
            if cfg.MODEL.INMODAL_ON and getattr(self.Inmodal_Mask_Head, 'SHARE_RES5', False):
                assert self.Inmodal_Mask_Head.res5.state_dict() == self.Box_Head.res5.state_dict()
            if cfg.MODEL.KEYPOINTS_ON and getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                assert self.Keypoint_Head.res5.state_dict() == self.Box_Head.res5.state_dict()

        if cfg.TRAIN.FREEZE_CONV_BODY:
            for p in self.Conv_Body.parameters():
                p.requires_grad = False
    def _init_modules(self):
        """Load ImageNet weights and configure the ReID backbone.

        - Loads ImageNet weights when ``cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS``.
        - When ``cfg.REID.FPN`` is false, prints it and freezes ``conv_top``,
          ``posthoc_modules`` and ``topdown_lateral_modules`` of ``Conv_Body``.
        - When ``cfg.TRAIN.FREEZE_CONV_BODY``, freezes ``Conv_Body.conv_body``.
        - When ``cfg.REID.REGULARIZED_POOLING``, stores a deep copy of
          ``Conv_Body.conv_body.res5`` as ``self.res5_1``.
        """
        if cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS:
            resnet_utils.load_pretrained_imagenet_weights(self)

        if not cfg.REID.FPN:
            print('cfg.REID.FPN:{}'.format(cfg.REID.FPN))
            for submodule_name in ('conv_top', 'posthoc_modules',
                                   'topdown_lateral_modules'):
                submodule = getattr(self.Conv_Body, submodule_name)
                for weight in submodule.parameters():
                    weight.requires_grad = False

        if cfg.TRAIN.FREEZE_CONV_BODY:
            for weight in self.Conv_Body.conv_body.parameters():
                weight.requires_grad = False

        if cfg.REID.REGULARIZED_POOLING:
            # deepcopy gives res5_1 its own parameters, independent of the
            # backbone's res5.
            self.res5_1 = copy.deepcopy(self.Conv_Body.conv_body.res5)
Exemplo n.º 9
0
 def _init_modules(self):
     """Load ImageNet ResNet weights and any configured detector weights.

     After the ResNet ImageNet weights are loaded, the conv body is frozen.
     Detector weights are then loaded through ``self.load_detector_weights``
     for every non-empty path, in the order VRD, VG, OI_REL (ResNet before
     VGG16 within each).
     """
     # VGG16 imagenet pretrained model is initialized in VGG16.py
     imagenet_path = cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS
     if imagenet_path != '':
         logger.info("Loading pretrained weights from %s", imagenet_path)
         resnet_utils.load_pretrained_imagenet_weights(self)
         for param in self.Conv_Body.parameters():
             param.requires_grad = False

     for weights_key in ('VRD_PRETRAINED_WEIGHTS', 'VG_PRETRAINED_WEIGHTS',
                         'OI_REL_PRETRAINED_WEIGHTS'):
         for backbone_cfg in (cfg.RESNETS, cfg.VGG16):
             detector_path = getattr(backbone_cfg, weights_key)
             if detector_path != '':
                 self.load_detector_weights(detector_path)
    def _init_modules(self):
        """Load pretrained / detector / relationship weights and set freezing.

        Steps, each gated on its config path being non-empty:

        1. ImageNet ResNet weights (``cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS``),
           after which ``Conv_Body`` is frozen. The VGG16 ImageNet model is
           initialised in VGG16.py instead.
        2. Detector weights via ``self.load_detector_weights`` for VRD, VG,
           OI_REL (ResNet before VGG16 within each).
        3. Predicate-branch checkpoint (``*_PRD_PRETRAINED_WEIGHTS``): its
           ``'model'`` entry, minus the ``Box_Outs`` classifier/regressor, is
           loaded into ``self.Prd_RCNN``, then its conv body / box head are
           frozen per ``cfg.TRAIN.FREEZE_PRD_*``.
        4. To-be-finetuned checkpoint (``*TO_BE_FINETUNED_WEIGHTS``) loaded
           into ``self``. ``Conv_Body`` and ``RPN`` are frozen, and
           ``Box_Head`` / ``Box_Outs`` are also frozen unless
           ``cfg.MODEL.UNFREEZE_DET``.
        5. Relationship checkpoint (``cfg.RESNETS.REL_PRETRAINED_WEIGHTS``):
           entries whose name contains ``'Prd_RCNN'`` go to ``self.Prd_RCNN``
           and those containing ``'RelDN'`` (minus the score layers) go to
           ``self.RelDN``.

        If several paths are configured for step 3 (order: ResNet-VRD,
        VGG16-VRD, ResNet-VG, VGG16-VG, ResNet-OI, VGG16-OI) or step 4
        (ResNet, VGG16), the last one wins, with the same precedence as
        before. Only that file is read, and a warning is logged.
        """
        def _freeze(module):
            # Exclude every parameter of ``module`` from gradient updates.
            for p in module.parameters():
                p.requires_grad = False

        def _load_last_configured(paths, log_msg):
            # Load the last non-empty path in ``paths`` (None if all empty).
            # Earlier paths would be overridden anyway, so they are not
            # deserialised; several non-empty paths is flagged as ambiguous.
            configured = [path for path in paths if path != '']
            if not configured:
                return None
            path = configured[-1]
            if len(configured) > 1:
                logger.warning("multiple weight paths configured %s; using the last one: %s",
                               configured, path)
            logger.info(log_msg, path)
            return torch.load(path, map_location=lambda storage, loc: storage)

        # VGG16 imagenet pretrained model is initialized in VGG16.py
        if cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS != '':
            logger.info("Loading pretrained weights from %s",
                        cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS)
            resnet_utils.load_pretrained_imagenet_weights(self)
            _freeze(self.Conv_Body)

        for detector_path in (cfg.RESNETS.VRD_PRETRAINED_WEIGHTS,
                              cfg.VGG16.VRD_PRETRAINED_WEIGHTS,
                              cfg.RESNETS.VG_PRETRAINED_WEIGHTS,
                              cfg.VGG16.VG_PRETRAINED_WEIGHTS,
                              cfg.RESNETS.OI_REL_PRETRAINED_WEIGHTS,
                              cfg.VGG16.OI_REL_PRETRAINED_WEIGHTS):
            if detector_path != '':
                self.load_detector_weights(detector_path)

        checkpoint = _load_last_configured(
            (cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS,
             cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS,
             cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS,
             cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS,
             cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS,
             cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS),
            "loading prd pretrained weights from %s")
        if checkpoint is not None:
            # not using the last softmax layers
            for key in ('Box_Outs.cls_score.weight', 'Box_Outs.cls_score.bias',
                        'Box_Outs.bbox_pred.weight', 'Box_Outs.bbox_pred.bias'):
                del checkpoint['model'][key]
            net_utils_rel.load_ckpt_rel(self.Prd_RCNN, checkpoint['model'])
            if cfg.TRAIN.FREEZE_PRD_CONV_BODY:
                _freeze(self.Prd_RCNN.Conv_Body)
            if cfg.TRAIN.FREEZE_PRD_BOX_HEAD:
                _freeze(self.Prd_RCNN.Box_Head)

        checkpoint = _load_last_configured(
            (cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS,
             cfg.VGG16.TO_BE_FINETUNED_WEIGHTS),
            "loading trained and to be finetuned weights from %s")
        if checkpoint is not None:
            net_utils_rel.load_ckpt_rel(self, checkpoint['model'])
            _freeze(self.Conv_Body)
            _freeze(self.RPN)
            if not cfg.MODEL.UNFREEZE_DET:
                _freeze(self.Box_Head)
                _freeze(self.Box_Outs)

        if cfg.RESNETS.REL_PRETRAINED_WEIGHTS != '':
            logger.info("loading rel pretrained weights from %s",
                        cfg.RESNETS.REL_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.RESNETS.REL_PRETRAINED_WEIGHTS,
                                    map_location=lambda storage, loc: storage)
            model_state = checkpoint['model']
            # A name containing both substrings goes into both dicts.
            prd_rcnn_state_dict = {name: model_state[name] for name in model_state
                                   if 'Prd_RCNN' in name}
            reldn_state_dict = {name: model_state[name] for name in model_state
                                if 'RelDN' in name}
            net_utils_rel.load_ckpt_rel(self.Prd_RCNN, prd_rcnn_state_dict)
            if cfg.TRAIN.FREEZE_PRD_CONV_BODY:
                _freeze(self.Prd_RCNN.Conv_Body)
            if cfg.TRAIN.FREEZE_PRD_BOX_HEAD:
                _freeze(self.Prd_RCNN.Box_Head)
            # Score layers are not reused. The predicate classifier must be
            # present (KeyError otherwise); the remaining ones are optional.
            del reldn_state_dict['RelDN.prd_cls_scores.weight']
            del reldn_state_dict['RelDN.prd_cls_scores.bias']
            for key in ('RelDN.prd_sbj_scores.weight', 'RelDN.prd_sbj_scores.bias',
                        'RelDN.prd_obj_scores.weight', 'RelDN.prd_obj_scores.bias',
                        'RelDN.spt_cls_scores.weight', 'RelDN.spt_cls_scores.bias'):
                reldn_state_dict.pop(key, None)
            net_utils_rel.load_ckpt_rel(self.RelDN, reldn_state_dict)
Exemplo n.º 11
0
    def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''):
        """Build the segmentation encoder (backbone) selected by ``arch``.

        Args:
            arch: Architecture key such as ``'resnet50_dilated8'`` or
                ``'se_resnext101_dilate16_32x4d'``; the branches below list
                every supported name.
            fc_dim: Not used by this method; kept for interface compatibility.
            weights: Optional path to an encoder state dict. When empty (or
                None), backbones are constructed with ``pretrained=True``;
                otherwise they are constructed with ``pretrained=False`` and
                this file is loaded with ``strict=False`` at the end.

        Returns:
            The constructed encoder.

        Raises:
            NotImplementedError: For the ResNet-34 variants.
            Exception: If ``arch`` is not recognised.

        Side effects:
            With ``cfg.SEM.FREEZE_BN`` set, the module-level name
            ``SynchronizedBatchNorm2d`` is rebound to ``AffineChannel2d``, which
            affects every later use of that name in this module.
        """
        if cfg.SEM.FREEZE_BN:
            print("Using AffineChannel2d as SynchronizedBatchNorm2d")
            from lib.nn import AffineChannel2d
            global SynchronizedBatchNorm2d
            SynchronizedBatchNorm2d = AffineChannel2d
        # Library-provided pretrained weights are used only when no explicit
        # encoder weights are given (those are loaded at the end instead).
        pretrained = not weights
        if arch == 'resnet18':
            orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        elif arch == 'resnet18_dilated8':
            orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet18_dilated16':
            orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=16)
        elif arch in ('resnet34', 'resnet34_dilated8', 'resnet34_dilated16'):
            raise NotImplementedError('%s is not supported' % arch)
        elif arch == 'resnet50':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)

        elif arch == 'ResnetX_dilated8' or arch == 'ResnetX_dilated16':
            # Body constructor named by cfg.MODEL.CONV_BODY; it is always
            # initialised from the ImageNet weights, independent of
            # ``pretrained``.
            # NOTE(review): the _dilated8/_dilated16 suffix is not used here;
            # dilation is whatever CONV_BODY builds -- confirm.
            # NOTE(review): eval() of a config string is only safe with a
            # trusted config; an explicit name -> constructor map is safer.
            net_encoder = eval(cfg.MODEL.CONV_BODY)()
            print('loading weights for ResNet')
            resnet_utils.load_pretrained_imagenet_weights(net_encoder)
            print('loading weights is done')
            if cfg.TRAIN.FREEZE_CONV_BODY:
                print('freeze train conv_body')
                for p in net_encoder.parameters():
                    p.requires_grad = False

        elif arch == 'resnet50_dilated8':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)

        elif arch == 'resnet50_dilated16':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=16)
        elif arch == 'resnet101':
            orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        elif arch == 'resnet101_dilated8':
            orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet101_dilated16':
            orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=16)

        elif arch == 'resnet152':
            orig_resnet = resnet.__dict__['resnet152'](pretrained=pretrained)
            net_encoder = Resnet152(orig_resnet)
        elif arch == 'resnet152_dilated8':
            orig_resnet = resnet.__dict__['resnet152'](pretrained=pretrained)
            net_encoder = Resnet152Dilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet152_dilated16':
            orig_resnet = resnet.__dict__['resnet152'](pretrained=pretrained)
            net_encoder = Resnet152Dilated(orig_resnet, dilate_scale=16)

        elif arch == 'resnext101':
            orig_resnext = resnext.__dict__['resnext101'](
                pretrained=pretrained)
            net_encoder = Resnet(orig_resnext)  # we can still use class Resnet

        elif arch == 'se_resnet50_dilate8':
            net_encoder = senet.__dict__['se_resnet50_dilate'](
                dilate=8, pretrained=pretrained)
        elif arch == 'se_resnet101_dilate8':
            net_encoder = senet.__dict__['se_resnet101_dilate'](
                dilate=8, pretrained=pretrained)
        elif arch == 'se_resnet152_dilate8':
            net_encoder = senet.__dict__['se_resnet152_dilate'](
                dilate=8, pretrained=pretrained)
        elif arch == 'se_resnext50_dilate8_32x4d':
            net_encoder = senet.__dict__['se_resnext50_dilate_32x4d'](
                dilate=8, pretrained=pretrained)
        elif arch == 'se_resnext101_dilate8_32x4d':
            net_encoder = senet.__dict__['se_resnext101_dilate_32x4d'](
                dilate=8, pretrained=pretrained)
        elif arch == 'senet154_dilate8':
            net_encoder = senet.__dict__['senet154_dilate'](
                dilate=8, pretrained=pretrained)

        elif arch == 'se_resnet50_dilate16':
            net_encoder = senet.__dict__['se_resnet50_dilate'](
                dilate=16, pretrained=pretrained)
        elif arch == 'se_resnet101_dilate16':
            net_encoder = senet.__dict__['se_resnet101_dilate'](
                dilate=16, pretrained=pretrained)
        elif arch == 'se_resnet152_dilate16':
            net_encoder = senet.__dict__['se_resnet152_dilate'](
                dilate=16, pretrained=pretrained)
        elif arch == 'se_resnext50_dilate16_32x4d':
            net_encoder = senet.__dict__['se_resnext50_dilate_32x4d'](
                dilate=16, pretrained=pretrained)
        elif arch == 'se_resnext101_dilate16_32x4d':
            net_encoder = senet.__dict__['se_resnext101_dilate_32x4d'](
                dilate=16, pretrained=pretrained)
        elif arch == 'senet154_dilate16':
            # Same constructor as senet154_dilate8; previously hard-coded
            # pretrained=False, unlike every other variant.
            net_encoder = senet.__dict__['senet154_dilate'](
                dilate=16, pretrained=pretrained)
        elif arch == 'xception':
            net_encoder = build_backbone('xception', 16, BatchNorm)
        else:
            raise Exception('Architecture undefined!')

        if weights:
            print('Loading weights for net_encoder')
            net_encoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage),
                strict=False)
        return net_encoder
Exemplo n.º 12
0
 def _init_modules(self):
     """Load pretrained weights into ``self`` via ``resnet_utils.load_pretrained_imagenet_weights``."""
     resnet_utils.load_pretrained_imagenet_weights(self)
Exemplo n.º 13
0
    def _init_modules(self):
        """Load pretrained / detector / predicate weights and set freezing.

        Steps, each gated on its config path being non-empty:

        1. ImageNet ResNet weights (``cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS``),
           after which ``Conv_Body`` is frozen. The VGG16 ImageNet model is
           initialised in VGG16.py instead.
        2. Detector weights via ``self.load_detector_weights`` for VRD, VG,
           OI_REL (ResNet before VGG16 within each).
        3. Predicate checkpoint (``*_PRD_PRETRAINED_WEIGHTS``): the weight and
           bias of layers 0 and 3 of its ``Box_Head.heads`` are copied into
           both ``self.Box_Head_sg`` and ``self.Box_Head_prd``.
        4. To-be-finetuned checkpoint (``*TO_BE_FINETUNED_WEIGHTS``) loaded
           into ``self``. ``Conv_Body`` and ``RPN`` are frozen, and
           ``Box_Head`` / ``Box_Outs`` are also frozen unless
           ``cfg.MODEL.UNFREEZE_DET``.

        If several paths are configured for step 3 (order: ResNet-VRD,
        VGG16-VRD, ResNet-VG, VGG16-VG, ResNet-OI, VGG16-OI) or step 4
        (ResNet, VGG16), the last one wins, with the same precedence as
        before. Only that file is read, and a warning is logged.
        """
        def _freeze(module):
            # Exclude every parameter of ``module`` from gradient updates.
            for p in module.parameters():
                p.requires_grad = False

        def _load_last_configured(paths, log_msg):
            # Load the last non-empty path in ``paths`` (None if all empty).
            # Earlier paths would be overridden anyway, so they are not
            # deserialised; several non-empty paths is flagged as ambiguous.
            configured = [path for path in paths if path != '']
            if not configured:
                return None
            path = configured[-1]
            if len(configured) > 1:
                logger.warning("multiple weight paths configured %s; using the last one: %s",
                               configured, path)
            logger.info(log_msg, path)
            return torch.load(path, map_location=lambda storage, loc: storage)

        # VGG16 imagenet pretrained model is initialized in VGG16.py
        if cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS != '':
            logger.info("Loading pretrained weights from %s",
                        cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS)
            resnet_utils.load_pretrained_imagenet_weights(self)
            _freeze(self.Conv_Body)

        for detector_path in (cfg.RESNETS.VRD_PRETRAINED_WEIGHTS,
                              cfg.VGG16.VRD_PRETRAINED_WEIGHTS,
                              cfg.RESNETS.VG_PRETRAINED_WEIGHTS,
                              cfg.VGG16.VG_PRETRAINED_WEIGHTS,
                              cfg.RESNETS.OI_REL_PRETRAINED_WEIGHTS,
                              cfg.VGG16.OI_REL_PRETRAINED_WEIGHTS):
            if detector_path != '':
                self.load_detector_weights(detector_path)

        checkpoint = _load_last_configured(
            (cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS,
             cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS,
             cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS,
             cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS,
             cfg.RESNETS.OI_REL_PRD_PRETRAINED_WEIGHTS,
             cfg.VGG16.OI_REL_PRD_PRETRAINED_WEIGHTS),
            "loading prd pretrained weights from %s")
        if checkpoint is not None:
            model_state = checkpoint['model']
            # Initialise both box heads from the same pretrained Box_Head
            # layers (indices 0 and 3 of ``heads``).
            for head in (self.Box_Head_sg, self.Box_Head_prd):
                for layer_idx in (0, 3):
                    layer = head.heads[layer_idx]
                    prefix = 'Box_Head.heads.%d.' % layer_idx
                    layer.weight.data.copy_(model_state[prefix + 'weight'])
                    layer.bias.data.copy_(model_state[prefix + 'bias'])

        checkpoint = _load_last_configured(
            (cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS,
             cfg.VGG16.TO_BE_FINETUNED_WEIGHTS),
            "loading trained and to be finetuned weights from %s")
        if checkpoint is not None:
            net_utils_rel.load_ckpt_rel(self, checkpoint['model'])
            _freeze(self.Conv_Body)
            _freeze(self.RPN)
            if not cfg.MODEL.UNFREEZE_DET:
                _freeze(self.Box_Head)
                _freeze(self.Box_Outs)