示例#1
0
    def may_load_ckpt(self, load_model=False, load_optimizer=False, load_lr_scheduler=False, strict=True):
        """Restore selected components from the checkpoint at ``self.cfg.log.ckpt_file``.

        This function is for the test part.

        :param load_model: whether ``self.model`` should be restored.
        :param load_optimizer: whether ``self.optimizer`` should be restored.
        :param load_lr_scheduler: whether ``self.lr_scheduler`` should be restored.
        :param strict: when True (default), every component is restored with its
            own ``load_state_dict`` method. When False, ``nn.Module`` components
            are restored through the module-level ``load_state_dict`` helper
            instead (defined elsewhere; presumably tolerant of key mismatches —
            verify). Non-module components always use their own method.
        :return: ``(epoch, score)`` as stored in the checkpoint.
        :raises AssertionError: if the checkpoint path does not exist or is not a
            regular file (these checks are skipped under ``python -O``).

        Requested components that are ``None`` on this instance are skipped and
        are not reported as loaded.
        """
        ckpt_file = self.cfg.log.ckpt_file
        assert osp.exists(ckpt_file), "ckpt_file {} does not exist!".format(ckpt_file)
        assert osp.isfile(ckpt_file), "ckpt_file {} is not file!".format(ckpt_file)
        # Load every storage onto the CPU so GPU-saved checkpoints load anywhere.
        ckpt = torch.load(ckpt_file, map_location=(lambda storage, loc: storage))

        # Attributes are only read when requested, so instances lacking e.g. an
        # lr_scheduler attribute still work when it is not asked for.
        load_ckpt = {}
        if load_model:
            load_ckpt['model'] = self.model
        if load_optimizer:
            load_ckpt['optimizer'] = self.optimizer
        if load_lr_scheduler:
            load_ckpt['lr_scheduler'] = self.lr_scheduler

        loaded = []
        for name, item in load_ckpt.items():
            if item is None:
                # Requested but absent on this instance: skip it and keep it
                # out of the log message.
                continue
            state_dict = ckpt['state_dicts'][name]
            # Only nn.Module.load_state_dict has a `strict` keyword argument, so
            # the non-strict helper path applies to modules only.
            if strict or not isinstance(item, torch.nn.Module):
                item.load_state_dict(state_dict)
            else:
                load_state_dict(item, state_dict)
            loaded.append(name)

        msg = '=> Loaded [{}] from {}, epoch {}, score:\n{}'.format(
            ', '.join(loaded), ckpt_file, ckpt['epoch'], ckpt['score'])
        print(msg)
        return ckpt['epoch'], ckpt['score']
示例#2
0
    def may_load_ckpt(self,
                      load_model=False,
                      load_optimizer=False,
                      load_lr_scheduler=False,
                      strict=True):
        """Resume state according to ``self.cfg.optim.resume_from``.

        This function is for the test part.

        Modes:

        * ``'pretrained'``: load the file named by
          ``self.pretrained_loaded_model_dict[cfg.dataset.test.names[0]]`` from
          ``cfg.log.exp_dir``. Its ``'state_dicts'`` entry holds the model state
          at index 0 and the optimizer state at index 1. The model state is
          handed to ``self.modify_model_modules_name``, the optimizer is rebuilt
          with ``optimizer_creation`` and its state restored, and
          ``self.save_ckpt`` is updated. Returns ``(self.resume_epoch, None)``.
          The ``load_*`` and ``strict`` flags are ignored in this mode.
        * ``'whole'``: restore the requested components from
          ``cfg.log.ckpt_file`` and return the ``(epoch, score)`` stored there.
        * any other value: nothing is loaded and ``None`` is returned.

        :param load_model: ('whole' mode) whether ``self.model`` is restored.
        :param load_optimizer: ('whole' mode) whether ``self.optimizer`` is restored.
        :param load_lr_scheduler: ('whole' mode) whether ``self.lr_scheduler`` is restored.
        :param strict: ('whole' mode) when True, every component uses its own
            ``load_state_dict``; when False, ``nn.Module`` components go through
            the module-level ``load_state_dict`` helper instead (defined
            elsewhere; presumably tolerant of key mismatches — verify).
        :return: see the modes above.
        :raises AssertionError: ('whole' mode) if ``cfg.log.ckpt_file`` does not
            exist or is not a regular file (skipped under ``python -O``).
        """
        # `cfg` is not a parameter or local of this method; read everything from
        # self.cfg, which is already the source of the paths used below.
        cfg = self.cfg
        resume_from = cfg.optim.resume_from
        # Compare by value: `is` tests object identity and is False for equal
        # strings built at runtime (e.g. parsed from a config file).
        if resume_from == 'pretrained':
            exp_dir = cfg.log.exp_dir  # e.g. D:/weights_results/Pyramidal_ReID/pre-trained
            weights_name = self.pretrained_loaded_model_dict[
                cfg.dataset.test.names[0]]
            pretrained = torch.load(osp.join(exp_dir, weights_name))
            model_dict = pretrained['state_dicts'][0]
            optimizer_dict = pretrained['state_dicts'][1]
            # NOTE(review): presumably renames the saved keys to the current
            # module names and loads them into self.model — verify.
            self.modify_model_modules_name(old_model_dict=model_dict)
            self.optimizer = optimizer_creation(cfg, self.model)
            optimizer_dict['param_groups'] = self.optimizer_load_state_dict(
                optimizer_dict)
            self.optimizer.load_state_dict(optimizer_dict)
            self.save_ckpt = {'model': self.model, 'optimizer': self.optimizer}
            return self.resume_epoch, None

        if resume_from == 'whole':
            ckpt_file = cfg.log.ckpt_file
            assert osp.exists(
                ckpt_file), "ckpt_file {} does not exist!".format(ckpt_file)
            assert osp.isfile(ckpt_file), "ckpt_file {} is not file!".format(
                ckpt_file)
            # Load every storage onto the CPU so GPU-saved checkpoints load anywhere.
            ckpt = torch.load(ckpt_file,
                              map_location=(lambda storage, loc: storage))

            # Attributes are only read when their component is requested.
            load_ckpt = {}
            if load_model:
                load_ckpt['model'] = self.model
            if load_optimizer:
                load_ckpt['optimizer'] = self.optimizer
            if load_lr_scheduler:
                load_ckpt['lr_scheduler'] = self.lr_scheduler

            loaded = []
            for name, item in load_ckpt.items():
                if item is None:
                    # Requested but absent on this instance: skip it and keep
                    # it out of the log message.
                    continue
                state_dict = ckpt['state_dicts'][name]
                # Only nn.Module.load_state_dict has a `strict` keyword
                # argument, so the non-strict helper path is for modules only.
                if strict or not isinstance(item, torch.nn.Module):
                    item.load_state_dict(state_dict)
                else:
                    load_state_dict(item, state_dict)
                loaded.append(name)

            msg = '=> Loaded [{}] from {}, epoch {}, score:\n{}'.format(
                ', '.join(loaded), ckpt_file, ckpt['epoch'], ckpt['score'])
            print(msg)
            return ckpt['epoch'], ckpt['score']

        # NOTE(review): unknown resume_from values load nothing and return None;
        # callers that unpack the result will fail. Consider raising ValueError.
        return None
示例#3
0
    def __init__(self, cfg):
        """Assemble the shared stem and the three MGN branches from a ResNet.

        ``self.backbone`` runs the ResNet from ``conv1`` through the first block
        of ``layer3``. Each branch ``p0``/``p1``/``p2`` owns deep copies of the
        remaining ``layer3`` blocks followed by a conv5 stage:

        * ``p0`` uses a copy of the ResNet's own ``layer4``.
        * ``p1`` and ``p2`` use copies of a rebuilt conv5 made of three
          ``Bottleneck`` blocks, constructed without a stride argument. It is
          initialised from ``layer4``'s weights via the module-level
          ``load_state_dict`` helper.

        The stride-free construction presumably keeps a larger feature map than
        ``layer4``; verify against ``Bottleneck`` and the backbone config.

        :param cfg: configuration passed through to ``create_backbone``.
        """
        super().__init__()
        resnet = create_backbone(cfg)

        # Shared stem: everything up to and including the first layer3 block.
        stem = (
            resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
            resnet.layer1, resnet.layer2, resnet.layer3[0],
        )
        self.backbone = nn.Sequential(*stem)

        conv4_tail = nn.Sequential(*resnet.layer3[1:])
        global_conv5 = resnet.layer4

        # 1x1 projection from 1024 to 2048 channels for the first part block.
        projection = nn.Sequential(
            nn.Conv2d(1024, 2048, 1, bias=False),
            nn.BatchNorm2d(2048),
        )
        part_conv5 = nn.Sequential(
            Bottleneck(1024, 512, downsample=projection),
            Bottleneck(2048, 512),
            Bottleneck(2048, 512),
        )
        # Copy layer4's weights into the rebuilt stage.
        load_state_dict(part_conv5, resnet.layer4.state_dict())

        def make_branch(conv5):
            # Deep copies give each branch its own, independent parameters.
            return nn.Sequential(copy.deepcopy(conv4_tail), copy.deepcopy(conv5))

        self.p0 = make_branch(global_conv5)
        self.p1 = make_branch(part_conv5)
        self.p2 = make_branch(part_conv5)
示例#4
0
def get_resnet(cfg):
    """Build the ResNet selected by ``cfg.model.backbone.name`` and return it.

    When ``cfg.model.backbone.pretrained`` is truthy, ImageNet weights are
    obtained with ``model_zoo.load_url``, using
    ``cfg.model.backbone.pretrained_model_dir`` as the download directory.
    They are applied with the module-level ``load_state_dict`` helper, and the
    weight file path is printed.
    """
    backbone_cfg = cfg.model.backbone
    arch = arch_dict[backbone_cfg.name]
    # ResNet reads cfg only for layer4 (its last conv stride).
    model = ResNet(arch.block, arch.layers, cfg)
    if not backbone_cfg.pretrained:
        return model

    url = model_urls[backbone_cfg.name]
    weights = model_zoo.load_url(
        url=url, model_dir=backbone_cfg.pretrained_model_dir)
    load_state_dict(model, weights)
    weights_path = osp.join(backbone_cfg.pretrained_model_dir, osp.basename(url))
    print('=> Loaded ImageNet Model: {}'.format(weights_path))
    return model
示例#5
0
 def load_state_dict(self, *args, **kwargs):
     """Load a state dict into the wrapped ``self.module``.

     Rather than calling ``self.module.load_state_dict``, this delegates to the
     module-level ``load_state_dict`` helper. All arguments are forwarded
     unchanged, and the helper's result is returned.
     """
     wrapped = self.module
     return load_state_dict(wrapped, *args, **kwargs)
示例#6
0
    def may_load_ckpt(self,
                      load_model=False,
                      load_optimizer=False,
                      load_lr_scheduler=False,
                      strict=True):
        """Resume state according to ``self.cfg.optim.resume_from``.

        Modes:

        * ``'pretrained'``: for each sub-module of ``self.model.model``
          (``encoder``, ``bnclassifiers``, ``bnclassifiers2``,
          ``graph_conv_net.gcn``, ``graph_matching_net``, ``verificator``),
          load ``<prefix>_<resume_epoch>.pkl`` from ``cfg.log.exp_dir`` with the
          sub-module's own ``load_state_dict``. Returns
          ``(cfg.optim.resume_epoch, None)``. The ``load_*`` and ``strict``
          flags are ignored in this mode.
        * ``'whole'``: restore the requested components from
          ``cfg.log.ckpt_file`` and return the ``(epoch, score)`` stored there.
        * any other value: nothing is loaded and ``None`` is returned.

        :param load_model: ('whole' mode) whether ``self.model`` is restored.
        :param load_optimizer: ('whole' mode) whether ``self.optimizer`` is restored.
        :param load_lr_scheduler: ('whole' mode) whether ``self.lr_scheduler`` is restored.
        :param strict: ('whole' mode) when True, every component uses its own
            ``load_state_dict``; when False, ``nn.Module`` components go through
            the module-level ``load_state_dict`` helper instead (defined
            elsewhere; presumably tolerant of key mismatches — verify).
        :return: see the modes above.
        :raises AssertionError: ('whole' mode) if ``cfg.log.ckpt_file`` does not
            exist or is not a regular file (skipped under ``python -O``).
        """
        # `cfg` is not a parameter or local of this method; read everything from
        # self.cfg, which is already the source of the paths used below.
        cfg = self.cfg
        exp_dir = cfg.log.exp_dir  # e.g. D:/weights_results/HOReID/pre-trained
        resume_epoch = cfg.optim.resume_epoch
        resume_from = cfg.optim.resume_from
        # Compare by value: `is` tests object identity and is False for equal
        # strings built at runtime (e.g. parsed from a config file).
        if resume_from == 'pretrained':
            net = self.model.model
            # (sub-module, file prefix) pairs; files are '<prefix>_<epoch>.pkl'.
            parts = (
                (net.encoder, 'encoder'),
                (net.bnclassifiers, 'bnclassifiers'),
                (net.bnclassifiers2, 'bnclassifiers2'),
                (net.graph_conv_net.gcn, 'gcn'),
                (net.graph_matching_net, 'gmnet'),
                (net.verificator, 'verificator'),
            )
            for module, prefix in parts:
                weights_file = osp.join(
                    exp_dir, '{}_{}.pkl'.format(prefix, resume_epoch))
                # Map storages to CPU, as the 'whole' branch does, so GPU-saved
                # files also load on CPU-only machines. load_state_dict copies
                # the values into the module's existing tensors.
                module.load_state_dict(
                    torch.load(weights_file,
                               map_location=(lambda storage, loc: storage)))
            return resume_epoch, None

        if resume_from == 'whole':
            ckpt_file = cfg.log.ckpt_file
            assert osp.exists(
                ckpt_file), "ckpt_file {} does not exist!".format(ckpt_file)
            assert osp.isfile(ckpt_file), "ckpt_file {} is not file!".format(
                ckpt_file)
            # Load every storage onto the CPU so GPU-saved checkpoints load anywhere.
            ckpt = torch.load(ckpt_file,
                              map_location=(lambda storage, loc: storage))

            # Attributes are only read when their component is requested.
            load_ckpt = {}
            if load_model:
                load_ckpt['model'] = self.model
            if load_optimizer:
                load_ckpt['optimizer'] = self.optimizer
            if load_lr_scheduler:
                load_ckpt['lr_scheduler'] = self.lr_scheduler

            loaded = []
            for name, item in load_ckpt.items():
                if item is None:
                    # Requested but absent on this instance: skip it and keep
                    # it out of the log message.
                    continue
                state_dict = ckpt['state_dicts'][name]
                # Only nn.Module.load_state_dict has a `strict` keyword
                # argument, so the non-strict helper path is for modules only.
                if strict or not isinstance(item, torch.nn.Module):
                    item.load_state_dict(state_dict)
                else:
                    load_state_dict(item, state_dict)
                loaded.append(name)

            msg = '=> Loaded [{}] from {}, epoch {}, score:\n{}'.format(
                ', '.join(loaded), ckpt_file, ckpt['epoch'], ckpt['score'])
            print(msg)
            return ckpt['epoch'], ckpt['score']

        # NOTE(review): unknown resume_from values load nothing and return None;
        # callers that unpack the result will fail. Consider raising ValueError.
        return None