Code example #1
    def may_load_ckpt(self,
                      load_model=False,
                      load_optimizer=False,
                      load_lr_scheduler=False,
                      strict=True):
        """
        :param load_model: determined if the model needs to be loaded or not.
        :param load_optimizer: determined if the optimizer needs to be loaded or not.
        :param load_lr_scheduler: determined if the lr_scheduler needs to be loaded or not.
        :param strict:
        :return:

        This function is for test part.
        """
        exp_dir = self.cfg.log.exp_dir  # D:/weights_results/Pyramidal_ReID/pre-trained
        # Resume from the epoch given by cfg.optim.resume_epoch.
        if self.cfg.optim.resume_from == 'pretrained':
            state_dict = torch.load(
                osp.join(
                    exp_dir, self.pretrained_loaded_model_dict[
                        self.cfg.dataset.test.names[0]]))
            model_dict = state_dict['state_dicts'][0]
            optimizer_dict = state_dict['state_dicts'][1]
            # Rename modules in the loaded dict to match the current model.
            self.modify_model_modules_name(old_model_dict=model_dict)
            self.optimizer = optimizer_creation(self.cfg, self.model)
            # Remap the saved param_groups onto the freshly built optimizer.
            optimizer_dict['param_groups'] = self.optimizer_load_state_dict(
                optimizer_dict)
            self.optimizer.load_state_dict(optimizer_dict)
            self.save_ckpt = {'model': self.model, 'optimizer': self.optimizer}
            return self.resume_epoch, None
        elif self.cfg.optim.resume_from == 'whole':
            ckpt_file = self.cfg.log.ckpt_file
            assert osp.exists(
                ckpt_file), "ckpt_file {} does not exist!".format(ckpt_file)
            assert osp.isfile(ckpt_file), "ckpt_file {} is not file!".format(
                ckpt_file)
            ckpt = torch.load(ckpt_file,
                              map_location=(lambda storage, loc: storage))

            load_ckpt = {}
            if load_model:
                load_ckpt['model'] = self.model
            if load_optimizer:
                load_ckpt['optimizer'] = self.optimizer
            if load_lr_scheduler:
                load_ckpt['lr_scheduler'] = self.lr_scheduler

            for name, item in load_ckpt.items():
                if item is not None:
                    # Optimizers and lr_schedulers do not take a `strict`
                    # argument, so they (and strict module loads) use the
                    # built-in load_state_dict; a non-strict module load
                    # goes through the lenient helper instead.
                    if not isinstance(item, torch.nn.Module) or strict:
                        item.load_state_dict(ckpt['state_dicts'][name])
                    else:
                        load_state_dict(item, ckpt['state_dicts'][name])

            load_ckpt_str = ', '.join(load_ckpt.keys())
            msg = '=> Loaded [{}] from {}, epoch {}, score:\n{}'.format(
                load_ckpt_str, ckpt_file, ckpt['epoch'], ckpt['score'])
            print(msg)
            return ckpt['epoch'], ckpt['score']
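
The non-strict branch above calls a standalone load_state_dict helper whose body is not shown here. As a hedged sketch, one plausible implementation keeps only the checkpoint entries whose key and tensor shape match the destination model (the helper name mirrors the call above, but this body is an assumption, not the original project's code):

def load_state_dict(model, src_state_dict):
    """Lenient load for an nn.Module: copy only entries whose key and
    tensor shape match the model; report everything that was skipped.
    (A sketch under assumptions; the original helper may differ.)"""
    dest_state_dict = model.state_dict()
    matched = {
        k: v for k, v in src_state_dict.items()
        if k in dest_state_dict and v.size() == dest_state_dict[k].size()
    }
    dest_state_dict.update(matched)
    model.load_state_dict(dest_state_dict)
    skipped = sorted(set(src_state_dict) - set(matched))
    if skipped:
        print('=> Skipped unmatched keys: {}'.format(skipped))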
Code example #2
    def __init__(self, cfg):
        self.cfg = cfg

        self.current_ep = 0
        # Init the test loader
        self.test_loader = dataloader_creation(self.cfg,
                                               mode='test',
                                               domain='source',
                                               train_type='Supervised')
        if not self.cfg.only_test:
            # The TensorBoard writer must NOT be stored in the EasyDict
            # config: assigning it to cfg.log.tb_writer raises an error,
            # so keep it as a plain attribute on self instead.
            if self.cfg.log.use_tensorboard:
                from tensorboardX import SummaryWriter
                self.tb_writer = SummaryWriter(
                    log_dir=osp.join(self.cfg.log.exp_dir, 'tensorboard'))
            else:
                self.tb_writer = None

            self.source_train_loader = dataloader_creation(
                self.cfg,
                mode='train',
                domain='source',
                train_type='Supervised')

            self.model = model_creation(self.cfg)
            self.optimizer = optimizer_creation(self.cfg, self.model)
            self.lr_scheduler = lr_scheduler_creation(self.cfg, self.optimizer,
                                                      self.source_train_loader)
            self.loss_functions = loss_function_creation(self.cfg,
                                                         self.tb_writer)
            self.analyze_functions = analyze_function_creation(
                self.cfg, self.tb_writer)

            self.epoch_start_time = 0
            self.trial_run_steps = 3 if self.cfg.optim.trial_run else None

            self.current_step = 0  # will NOT be reset between epochs
            self.steps_per_log = self.cfg.optim.steps_per_log
            self.print_step_log = print_step_log

            self.current_ep = 0
            self.print_ep_log = None  # assigned a logging function later
            self.eps_per_log = 1

            self.save_ckpt = {
                'model': self.model,
                'optimizer': self.optimizer,
                'lr_scheduler': self.lr_scheduler
            }

        else:
            # Init the test part
            self.model = model_creation(self.cfg)
            self.current_ep, _ = self.may_load_ckpt()
        if self.cfg.optim.resume:
            self.resume()
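
For reference, here is a minimal sketch of the config object this constructor reads, assuming the EasyDict type mentioned in the comments; the field values below are illustrative placeholders, not the project's defaults:

from easydict import EasyDict

# Only the fields accessed by __init__ above, with placeholder values.
cfg = EasyDict({
    'only_test': True,
    'log': {'exp_dir': './exp', 'use_tensorboard': False},
    'optim': {'resume': False, 'resume_from': 'whole',
              'trial_run': False, 'steps_per_log': 50},
    'dataset': {'test': {'names': ['market1501']}},
})
print(cfg.log.exp_dir)  # EasyDict allows attribute-style access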
Code example #3
    def __init__(self, cfg):
        self.cfg = cfg
        # Init the train part
        if not self.cfg.only_test:
            # The TensorBoard writer must NOT be stored in the EasyDict
            # config: assigning it to cfg.log.tb_writer raises an error,
            # so keep it as a plain attribute on self instead.
            if self.cfg.log.use_tensorboard:
                from tensorboardX import SummaryWriter
                self.tb_writer = SummaryWriter(
                    log_dir=osp.join(self.cfg.log.exp_dir, 'tensorboard'))
            else:
                self.tb_writer = None

            self.source_train_loader = dataloader_creation(
                self.cfg,
                mode='train',
                domain='source',
                train_type='Supervised')

            self.model = model_creation(self.cfg)
            self.optimizer = optimizer_creation(self.cfg, self.model)
            self.lr_scheduler = lr_scheduler_creation(self.cfg, self.optimizer,
                                                      self.source_train_loader)
            self.loss_functions = loss_function_creation(self.cfg,
                                                         self.tb_writer)
            self.analyze_functions = None

            self.epoch_start_time = 0
            self.trial_run_steps = 3 if self.cfg.optim.trial_run else None

            self.current_step = 0  # will NOT be reset between epochs
            self.steps_per_log = self.cfg.optim.steps_per_log
            self.print_step_log = print_step_log

            self.current_ep = 0
            self.print_ep_log = None  # assigned a logging function later
            self.eps_per_log = 1

            self.save_ckpt = {
                'model': self.model,
                'optimizer': self.optimizer,
                'lr_scheduler': self.lr_scheduler
            }
        else:
            # Init the test part
            self.model = model_creation(self.cfg)
            self.resume_epoch = self.cfg.optim.resume_epoch  # 112
            self.pretrained_loaded_model_dict = {
                'market1501':
                'ckpt_ep{}_re02_bs64_dropout02_GPU0_mAP0.882439013042_{}.pth'.
                format(self.resume_epoch, self.cfg.dataset.test.names[0]),
                'duke':
                'ckpt_ep{}_re02_bs64_dropout02_GPU2_mAP0.788985533455_{}.pth'.
                format(self.resume_epoch, self.cfg.dataset.test.names[0]),
                'cuhk03_np_detected_jpg':
                'ckpt_ep{}_re02_bs64_dropout02_GPU2_mAP0.747726555617_{}.pth'.
                format(self.resume_epoch, self.cfg.dataset.test.names[0])
            }
            self.current_ep, _ = self.may_load_ckpt(load_model=True,
                                                    strict=False)
        # Init the test loader
        self.test_loader = dataloader_creation(self.cfg,
                                               mode='test',
                                               domain='source',
                                               train_type='Supervised')
        if self.cfg.optim.resume:
            self.resume()
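
Code example #1's 'pretrained' branch also calls self.modify_model_modules_name, whose body is not shown. When published checkpoints were saved from a model wrapped in nn.DataParallel, their keys carry a 'module.' prefix; below is a hedged sketch of what such a renaming helper might do (an assumption, not the original implementation):

from collections import OrderedDict

def strip_data_parallel_prefix(old_model_dict):
    """Return a copy of old_model_dict with any 'module.' prefix removed
    from the keys, so it matches an unwrapped model's state_dict.
    (A sketch; modify_model_modules_name may rename keys differently.)"""
    new_model_dict = OrderedDict()
    for key, value in old_model_dict.items():
        new_key = key[len('module.'):] if key.startswith('module.') else key
        new_model_dict[new_key] = value
    return new_model_dict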