Example no. 1
0
    def __init__(self, cfg, model, train_dl, val_dl,
                 loss_func, num_query, num_gpus):
        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.loss_func = loss_func
        self.num_query = num_query

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.train_epoch = 1
        self.batch_cnt = 0

        self.logger = logging.getLogger('reid_baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR
        self.device = cfg.MODEL.DEVICE
        self.epochs = cfg.SOLVER.MAX_EPOCHS

        if num_gpus > 1:
            
            # Multi-GPU model without FP16
            self.model = nn.DataParallel(self.model)
            if cfg.SOLVER.SYNCBN:
                # convert to use sync_bn
                self.logger.info('More than one gpu used, convert model to use SyncBN.')
                self.model = convert_model(self.model)
                self.logger.info('Using pytorch SyncBN implementation')
            self.model.cuda()
            
            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            
            self.mix_precision = False
            self.logger.info(self.model)
            self.logger.info(self.optim)
            self.logger.info('Trainer Built')
            return

        else:
            # Single GPU model
            self.model.cuda()
    
            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.model)
            self.logger.info(self.optim)
            self.mix_precision = False
            return
Example no. 2
0
    device = torch.device("cpu")
    num_gpus = 0
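    # num_gpus is inferred from the length of the DEVICE_IDS string (minus one),
    # and any leading/trailing 'd' characters are stripped before building the
    # device string; the exact DEVICE_IDS convention is specific to this config.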
    if cfg.MODEL.DEVICE == 'cuda' and torch.cuda.is_available():
        num_gpus = len(cfg.MODEL.DEVICE_IDS) - 1
        device_ids = cfg.MODEL.DEVICE_IDS.strip("d")
        print(device_ids)
        device = torch.device("cuda:{0}".format(device_ids))

    model = build_model(cfg)
    para_dict = torch.load(r'/usr/demo/common_data/baseline_epoch363.pth')

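    # Wrap the model the same way it was wrapped during training (DataParallel /
    # SyncBN) before load_state_dict, presumably so the checkpoint's
    # 'module.'-prefixed keys match the current state_dict.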
    if num_gpus > 1:
        model = torch.nn.DataParallel(model)
    if cfg.SOLVER.SYNCBN:
        model = convert_model(model)
    model.load_state_dict(para_dict)

    main_transform = A.Compose([
        A.Resize(cfg.INPUT.RESIZE_TEST[0], cfg.INPUT.RESIZE_TEST[1]),
        # A.CenterCrop(cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1])
    ])
    image_transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
    ])

    dataset = Fashion_MNIST_DataSet(cfg, mode='val', main_transform=main_transform,
                                        img_transform=image_transform)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.TEST.IMS_PER_BATCH,       # assumed config key for the eval batch size
        shuffle=False,
        num_workers=cfg.DATALOADER.NUM_WORKERS)  # assumed config key for loader workers
Example no. 3
0
    def __init__(self, cfg, model, train_dl, val_dl, loss_func, num_query,
                 num_gpus):
        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.loss_func = loss_func
        self.num_query = num_query

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.train_epoch = 1
        self.batch_cnt = 0

        self.logger = logging.getLogger('reid_baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR
        self.device = cfg.MODEL.DEVICE
        self.epochs = cfg.SOLVER.MAX_EPOCHS

        if cfg.SOLVER.TENSORBOARD.USE:
            summary_dir = os.path.join(cfg.OUTPUT_DIR, 'summaries/')
            os.makedirs(summary_dir, exist_ok=True)
            self.summary_writer = SummaryWriter(log_dir=summary_dir)
        self.current_iteration = 0

        self.model.cuda()
        self.logger.info(self.model)

        if num_gpus > 1:

            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.optim)

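            # Mixed precision is enabled whenever the apex opt_level is anything
            # other than "O0" (pure FP32).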
            self.mix_precision = (cfg.MODEL.OPT_LEVEL != "O0")
            if self.mix_precision:
                self.model, self.optim = amp.initialize(
                    self.model, self.optim, opt_level=cfg.MODEL.OPT_LEVEL)
                self.logger.info(
                    'Using apex for mix_precision with opt_level {}'.format(
                        cfg.MODEL.OPT_LEVEL))

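            # amp.initialize runs before the DataParallel wrap, following the
            # usual apex recommendation.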
            self.model = nn.DataParallel(self.model)
            if cfg.SOLVER.SYNCBN:
                if self.mix_precision:
                    self.model = apex.parallel.convert_syncbn_model(self.model)
                    self.logger.info(
                        'More than one gpu used, convert model to use SyncBN.')
                    self.logger.info('Using apex SyncBN implementation')
                else:
                    self.model = convert_model(self.model)
                    self.model.cuda()
                    self.logger.info(
                        'More than one gpu used, convert model to use SyncBN.')
                    self.logger.info('Using pytorch SyncBN implementation')
                    self.logger.info(self.model)

            self.logger.info('Trainer Built')

            return

        else:
            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.optim)

            self.mix_precision = (cfg.MODEL.OPT_LEVEL != "O0")
            if self.mix_precision:
                self.model, self.optim = amp.initialize(
                    self.model, self.optim, opt_level=cfg.MODEL.OPT_LEVEL)
                self.logger.info(
                    'Using apex for mix_precision with opt_level {}'.format(
                        cfg.MODEL.OPT_LEVEL))

            return
Example no. 4
0
    def __init__(self, cfg, model, train_dl, val_dl, exemplar_dl, loss_func,
                 num_query, num_gpus):
        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.exemplar_dl = exemplar_dl
        self.loss_func = loss_func
        self.num_query = num_query

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.train_epoch = 1
        self.batch_cnt = 0

        self.logger = logging.getLogger('reid_baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR
        self.device = cfg.MODEL.DEVICE
        self.epochs = cfg.SOLVER.MAX_EPOCHS

        self.model.cuda()
        self.logger.info(self.model)
        # Exemplar memory loss: one feature slot per sample in the exemplar set
        self.exemplar_memory = ExemplarMemoryLoss(
            cfg.DATASETS.EXEMPLAR.MEMORY.NUM_FEATS,
            len(exemplar_dl.dataset),
            beta=cfg.DATASETS.EXEMPLAR.MEMORY.BETA,
            knn=cfg.DATASETS.EXEMPLAR.MEMORY.KNN,
            alpha=cfg.DATASETS.EXEMPLAR.MEMORY.ALPHA,
            knn_start_epoch=cfg.DATASETS.EXEMPLAR.MEMORY.KNN_START_EPOCH)
        self.exemplar_memory.cuda()
        self.logger.info(self.exemplar_memory)
        # Manually-advanced iterator over the exemplar (target) loader
        self.exemplar_iter = iter(exemplar_dl)

        if num_gpus > 1:

            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.optim)

            self.mix_precision = (cfg.MODEL.OPT_LEVEL != "O0")
            if self.mix_precision:
                self.model, self.optim = amp.initialize(
                    self.model, self.optim, opt_level=cfg.MODEL.OPT_LEVEL)
                self.logger.info(
                    'Using apex for mix_precision with opt_level {}'.format(
                        cfg.MODEL.OPT_LEVEL))

            self.model = nn.DataParallel(self.model)
            if cfg.SOLVER.SYNCBN:
                if self.mix_precision:
                    self.model = apex.parallel.convert_syncbn_model(self.model)
                    self.logger.info(
                        'More than one gpu used, convert model to use SyncBN.')
                    self.logger.info('Using apex SyncBN implementation')
                else:
                    self.model = convert_model(self.model)
                    self.logger.info(
                        'More than one gpu used, convert model to use SyncBN.')
                    self.logger.info('Using pytorch SyncBN implementation')

            self.logger.info('Trainer Built')

            return

        else:

            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.optim)

            self.mix_precision = (cfg.MODEL.OPT_LEVEL != "O0")
            if self.mix_precision:
                self.model, self.optim = amp.initialize(
                    self.model, self.optim, opt_level=cfg.MODEL.OPT_LEVEL)
                self.logger.info(
                    'Using apex for mix_precision with opt_level {}'.format(
                        cfg.MODEL.OPT_LEVEL))

            return
Example no. 5
0
    def __init__(self, cfg, model, train_dl, val_dl,exemplar_dl,
                 loss_func, num_query, num_gpus):
        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.exemplar_dl = exemplar_dl
        self.loss_func = loss_func
        self.num_query = num_query

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.train_epoch = 1
        self.batch_cnt = 0

        self.logger = logging.getLogger('reid_baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR
        self.device = cfg.MODEL.DEVICE
        self.epochs = cfg.SOLVER.MAX_EPOCHS



        self.model.cuda()
        self.logger.info(self.model)
        # Manually-advanced iterator over the exemplar (target) loader
        self.exemplar_iter = iter(exemplar_dl)

        if num_gpus > 1:
        
            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.optim)

            self.mix_precision = (cfg.MODEL.OPT_LEVEL != "O0")
            if self.mix_precision:
                self.model, self.optim = amp.initialize(self.model, self.optim, opt_level=cfg.MODEL.OPT_LEVEL)
                self.logger.info('Using apex for mix_precision with opt_level {}'.format(cfg.MODEL.OPT_LEVEL))

            self.model = nn.DataParallel(self.model)
            if cfg.SOLVER.SYNCBN:
                if self.mix_precision:
                    self.model = apex.parallel.convert_syncbn_model(self.model)
                    self.logger.info('More than one gpu used, convert model to use SyncBN.')
                    self.logger.info('Using apex SyncBN implementation')
                else:
                    self.model = convert_model(self.model)
                    self.logger.info('More than one gpu used, convert model to use SyncBN.')
                    self.logger.info('Using pytorch SyncBN implementation')
            # [todo] test with mix precision
            if cfg.MODEL.WEIGHT != '':
                self.logger.info('Loading weight from {}'.format(cfg.MODEL.WEIGHT))
                param_dict = torch.load(cfg.MODEL.WEIGHT)
                self.model.load_state_dict(param_dict)

            self.logger.info('Trainer Built')

            return

        else:

            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.optim)
            
            self.mix_precision = (cfg.MODEL.OPT_LEVEL != "O0")
            if self.mix_precision:
                self.model, self.optim = amp.initialize(self.model, self.optim, opt_level=cfg.MODEL.OPT_LEVEL)
                self.logger.info('Using apex for mix_precision with opt_level {}'.format(cfg.MODEL.OPT_LEVEL))
            if cfg.MODEL.WEIGHT != '':
                self.logger.info('Loading weight from {}'.format(cfg.MODEL.WEIGHT))
                param_dict = torch.load(cfg.MODEL.WEIGHT)
                self.model.load_state_dict(param_dict)
            self.logger.info('Trainer Built')

            return
Example no. 6
0
    def __init__(self, cfg, model, train_dl, val_dl, loss_func, num_gpus,
                 device):

        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.loss_func = loss_func

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.f1_avg = AvgerageMeter()

        self.val_loss_avg = AvgerageMeter()
        self.val_acc_avg = AvgerageMeter()
        self.device = device

        self.train_epoch = 1

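        # With warmup enabled the optimizer starts at BASE_LR * 0.1; the
        # GradualWarmupScheduler built below (multiplier=10) ramps it back up
        # to BASE_LR over the warmup epochs.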
        if cfg.SOLVER.USE_WARMUP:
            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR * 0.1,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
        else:
            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
        if cfg.SOLVER.RESUME:
            print("Resume from checkpoint...")
            checkpoint = torch.load(cfg.SOLVER.RESUME_CHECKPOINT)
            param_dict = checkpoint['model_state_dict']
            self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
            for state in self.optim.state.values():
                for k, v in state.items():
                    print(type(v))
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)
            self.train_epoch = checkpoint['epoch'] + 1
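            # Copy weights into the current model: strip the 'module.' prefix left
            # by DataParallel checkpoints and skip classifier/fc heads (they keep
            # their fresh initialisation).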
            for i in param_dict:
                if i.startswith("module"):
                    new_i = i[7:]
                else:
                    new_i = i
                if 'classifier' in i or 'fc' in i:
                    continue
                self.model.state_dict()[new_i].copy_(param_dict[i])

        self.batch_cnt = 0

        self.logger = logging.getLogger('baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR

        self.epochs = cfg.SOLVER.MAX_EPOCHS

        if cfg.SOLVER.TENSORBOARD.USE:
            summary_dir = os.path.join(cfg.OUTPUT_DIR, 'summaries/')
            os.makedirs(summary_dir, exist_ok=True)
            self.summary_writer = SummaryWriter(log_dir=summary_dir)
        self.current_iteration = 0

        self.logger.info(self.model)

        if self.cfg.SOLVER.USE_WARMUP:

            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optim, self.epochs, eta_min=cfg.SOLVER.MIN_LR)
            self.scheduler = GradualWarmupScheduler(
                self.optim,
                multiplier=10,
                total_epoch=cfg.SOLVER.WARMUP_EPOCH,
                after_scheduler=scheduler_cosine)
            # self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
            #                                cfg.SOLVER.WARMUP_EPOCH, cfg.SOLVER.WARMUP_METHOD)
        else:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optim, self.epochs, eta_min=cfg.SOLVER.MIN_LR)

        if num_gpus > 1:

            self.logger.info(self.optim)
            self.model = nn.DataParallel(self.model)
            if cfg.SOLVER.SYNCBN:
                self.model = convert_model(self.model)
                self.model = self.model.to(device)
                self.logger.info(
                    'More than one gpu used, convert model to use SyncBN.')
                self.logger.info('Using pytorch SyncBN implementation')
                self.logger.info(self.model)

            self.logger.info('Trainer Built')

            return

        else:
            self.model = self.model.to(device)
            self.logger.info('Cpu used.')
            self.logger.info(self.model)
            self.logger.info('Trainer Built')

            return
Example no. 7
0
    def __init__(self, cfg, model, train_dl, val_dl, loss_func, num_query,
                 num_gpus):
        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.loss_func = loss_func
        self.num_query = num_query

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.train_epoch = 1
        self.batch_cnt = 0

        self.logger = logging.getLogger('reid_baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR
        self.device = cfg.MODEL.DEVICE
        self.epochs = cfg.SOLVER.MAX_EPOCHS

        if cfg.SOLVER.TENSORBOARD.USE:
            summary_dir = os.path.join(cfg.OUTPUT_DIR, 'summaries/')
            os.makedirs(summary_dir, exist_ok=True)
            self.summary_writer = SummaryWriter(log_dir=summary_dir)
        self.current_iteration = 0

        self.model.cuda()
        self.logger.info(self.model)

        if num_gpus > 1:

            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.optim)

            self.mix_precision = (cfg.MODEL.OPT_LEVEL != "O0")
            if self.mix_precision:
                self.model, self.optim = amp.initialize(
                    self.model, self.optim, opt_level=cfg.MODEL.OPT_LEVEL)
                self.logger.info(
                    'Using apex for mix_precision with opt_level {}'.format(
                        cfg.MODEL.OPT_LEVEL))

            self.model = nn.DataParallel(self.model)
            if cfg.SOLVER.SYNCBN:
                if self.mix_precision:
                    self.model = apex.parallel.convert_syncbn_model(self.model)
                    self.logger.info(
                        'More than one gpu used, convert model to use SyncBN.')
                    self.logger.info('Using apex SyncBN implementation')
                else:
                    self.model = convert_model(self.model)
                    self.model.cuda()
                    self.logger.info(
                        'More than one gpu used, convert model to use SyncBN.')
                    self.logger.info('Using pytorch SyncBN implementation')
                    self.logger.info(self.model)
            # [todo] test with mix precision
            if cfg.MODEL.WEIGHT != '':
                self.logger.info('Loading weight from {}'.format(
                    cfg.MODEL.WEIGHT))
                param_dict = torch.load(cfg.MODEL.WEIGHT)

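                # Checkpoints saved from a DataParallel model prefix every key with
                # 'module.'; detect that and strip the prefix before matching keys.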
                start_with_module = False
                for k in param_dict.keys():
                    if k.startswith('module.'):
                        start_with_module = True
                        break
                if start_with_module:
                    param_dict = {k[7:]: v for k, v in param_dict.items()}

                # Compare against the unwrapped module: self.model was wrapped in
                # nn.DataParallel above, so its own keys still carry the 'module.'
                # prefix that has just been stripped from param_dict.
                model_state = self.model.module.state_dict()
                print('ignore_param:')
                print([
                    k for k, v in param_dict.items()
                    if k not in model_state
                    or model_state[k].size() != v.size()
                ])
                print('unload_param:')
                print([
                    k for k, v in model_state.items()
                    if k not in param_dict.keys()
                    or param_dict[k].size() != v.size()
                ])

                param_dict = {
                    k: v
                    for k, v in param_dict.items() if k in model_state
                    and model_state[k].size() == v.size()
                }
                for i in param_dict:
                    model_state[i].copy_(param_dict[i])
                # self.model.load_state_dict(param_dict)
            self.logger.info('Trainer Built')

            return

        else:
            if cfg.SOLVER.FIX_BACKBONE:

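                # Freeze everything except the 'reduction_' and 'fc_id_' heads and
                # hand only the trainable parameters to the optimizer.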
                print('==>fix backbone')
                param_list = []
                for k, v in self.model.named_parameters():
                    if 'reduction_' not in k and 'fc_id_' not in k:
                        v.requires_grad = False  # freeze this parameter
                    else:
                        param_list.append(v)
                        print(k)
                self.optim = make_optimizer_partial(
                    param_list,
                    opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                    lr=cfg.SOLVER.BASE_LR,
                    weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                    momentum=0.9)

            else:
                self.optim = make_optimizer(
                    self.model,
                    opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                    lr=cfg.SOLVER.BASE_LR,
                    weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                    momentum=0.9)
            self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS,
                                               cfg.SOLVER.GAMMA,
                                               cfg.SOLVER.WARMUP_FACTOR,
                                               cfg.SOLVER.WARMUP_EPOCH,
                                               cfg.SOLVER.WARMUP_METHOD)
            self.logger.info(self.optim)

            self.mix_precision = (cfg.MODEL.OPT_LEVEL != "O0")
            if self.mix_precision:
                self.model, self.optim = amp.initialize(
                    self.model, self.optim, opt_level=cfg.MODEL.OPT_LEVEL)
                self.logger.info(
                    'Using apex for mix_precision with opt_level {}'.format(
                        cfg.MODEL.OPT_LEVEL))

            if cfg.MODEL.WEIGHT != '':
                self.logger.info('Loading weight from {}'.format(
                    cfg.MODEL.WEIGHT))
                param_dict = torch.load(cfg.MODEL.WEIGHT)

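                # Same 'module.' prefix handling as in the multi-GPU branch above.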
                start_with_module = False
                for k in param_dict.keys():
                    if k.startswith('module.'):
                        start_with_module = True
                        break
                if start_with_module:
                    param_dict = {k[7:]: v for k, v in param_dict.items()}

                print('ignore_param:')
                print([
                    k for k, v in param_dict.items()
                    if k not in self.model.state_dict()
                    or self.model.state_dict()[k].size() != v.size()
                ])
                print('unload_param:')
                print([
                    k for k, v in self.model.state_dict().items()
                    if k not in param_dict.keys()
                    or param_dict[k].size() != v.size()
                ])

                param_dict = {
                    k: v
                    for k, v in param_dict.items()
                    if k in self.model.state_dict()
                    and self.model.state_dict()[k].size() == v.size()
                }
                for i in param_dict:
                    self.model.state_dict()[i].copy_(param_dict[i])
                # for k,v in self.model.named_parameters():
                #     if 'reduction_' not in k:
                #         print(v.requires_grad)  # ideally all values are False
            return