Code example #1
0
    def __init__(self, cfg_path: str) -> None:
        """Build a distributed (DDP + Apex AMP O1) RetinaNet training context.

        Reads a YAML config with top-level sections 'data', 'model', 'optim',
        'hyper_params', 'val' plus a 'gpus' string, then constructs the COCO
        datasets/loaders, the DDP-wrapped model, optimizer, EMA, the ATSS
        loss, the warm-up LR schedule, and a rank-0 TensorBoard writer.

        Args:
            cfg_path: path to the YAML configuration file.
        """
        with open(cfg_path, 'r') as rf:
            self.cfg = yaml.safe_load(rf)
        # Cache each config section; they are consulted throughout training.
        self.data_cfg = self.cfg['data']
        self.model_cfg = self.cfg['model']
        self.optim_cfg = self.cfg['optim']
        self.hyper_params = self.cfg['hyper_params']
        self.val_cfg = self.cfg['val']
        # Echo the full configuration into the run log.
        print(self.data_cfg)
        print(self.model_cfg)
        print(self.optim_cfg)
        print(self.hyper_params)
        print(self.val_cfg)
        # Limit visible GPUs before the NCCL process group is initialized.
        os.environ['CUDA_VISIBLE_DEVICES'] = self.cfg['gpus']
        dist.init_process_group(backend='nccl')
        # Training set: augmentation on; 'remove_blank' may drop box-less images.
        self.tdata = COCODataSets(img_root=self.data_cfg['train_img_root'],
                                  annotation_path=self.data_cfg['train_annotation_path'],
                                  img_size=self.data_cfg['img_size'],
                                  debug=self.data_cfg['debug'],
                                  use_crowd=self.data_cfg['use_crowd'],
                                  augments=True,
                                  remove_blank=self.data_cfg['remove_blank']
                                  )
        # DistributedSampler shards batches across ranks (shuffled for training).
        self.tloader = DataLoader(dataset=self.tdata,
                                  batch_size=self.data_cfg['batch_size'],
                                  num_workers=self.data_cfg['num_workers'],
                                  collate_fn=self.tdata.collate_fn,
                                  sampler=DistributedSampler(dataset=self.tdata, shuffle=True))
        # Validation set: no augmentation; empty images are kept.
        self.vdata = COCODataSets(img_root=self.data_cfg['val_img_root'],
                                  annotation_path=self.data_cfg['val_annotation_path'],
                                  img_size=self.data_cfg['img_size'],
                                  debug=self.data_cfg['debug'],
                                  use_crowd=self.data_cfg['use_crowd'],
                                  augments=False,
                                  remove_blank=False
                                  )
        self.vloader = DataLoader(dataset=self.vdata,
                                  batch_size=self.data_cfg['batch_size'],
                                  num_workers=self.data_cfg['num_workers'],
                                  collate_fn=self.vdata.collate_fn,
                                  sampler=DistributedSampler(dataset=self.vdata, shuffle=False))
        print("train_data: ", len(self.tdata), " | ",
              "val_data: ", len(self.vdata), " | ",
              "empty_data: ", self.tdata.empty_images_len)
        print("train_iter: ", len(self.tloader), " | ",
              "val_iter: ", len(self.vloader))
        model = RetinaNet(num_cls=self.model_cfg['num_cls'],
                          anchor_ratios=self.model_cfg['anchor_ratios'],
                          anchor_scales=self.model_cfg['anchor_scales'],
                          anchor_sizes=self.model_cfg['anchor_sizes'],
                          strides=self.model_cfg['strides'],
                          backbone=self.model_cfg['backbone'],
                          )
        # Optionally warm-start the backbone from a checkpoint file.
        if self.model_cfg.get("backbone_weight", None):
            weights = torch.load(self.model_cfg['backbone_weight'])
            # NOTE(review): 'load_backbone_weighs' looks like a typo for
            # 'load_backbone_weights' in the model API — confirm the method name.
            model.load_backbone_weighs(weights)

        # Best validation metrics seen so far; updated during validation.
        self.best_map = 0.
        self.best_map50 = 0.
        optimizer = split_optimizer(model, self.optim_cfg)
        # NOTE(review): dist.get_rank() returns the *global* rank; using it as
        # the CUDA device index is only correct for single-node runs — confirm.
        local_rank = dist.get_rank()
        self.local_rank = local_rank
        self.device = torch.device("cuda", local_rank)
        model.to(self.device)
        # Apex mixed precision (O1); the model is already on its CUDA device.
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
        if self.optim_cfg['sync_bn']:
            # NOTE(review): SyncBN conversion happens *after* amp.initialize
            # here — confirm this ordering is intended.
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        self.model = nn.parallel.distributed.DistributedDataParallel(model,
                                                                     device_ids=[local_rank],
                                                                     output_device=local_rank)
        self.optimizer = optimizer
        # EMA keeps an exponentially smoothed copy of the model's weights.
        self.ema = ModelEMA(self.model)

        # WARNING: eval() on a config-supplied string — only safe with trusted
        # config files.
        beta = eval(self.hyper_params['beta']) if isinstance(self.hyper_params['beta'], str) \
            else self.hyper_params['beta']

        # ATSS loss; anchors per location = |ratios| * |scales|.
        self.creterion = ATSSRetinaLoss(top_k=self.hyper_params['top_k'],
                                        anchor_per_loc=len(self.model_cfg['anchor_ratios']) * len(self.model_cfg[
                                                                                                      'anchor_scales']),
                                        alpha=self.hyper_params['alpha'],
                                        gamma=self.hyper_params['gamma'],
                                        beta=beta,
                                        )
        # Warm-up + cosine/multi-step LR schedule (iter_per_epoch is the
        # train-loader length, so it is presumably stepped per iteration).
        self.lr_adjuster = WarmUpCosineDecayMultiStepLRAdjust(init_lr=self.optim_cfg['lr'],
                                                              milestones=self.optim_cfg['milestones'],
                                                              warm_up_epoch=self.optim_cfg['warm_up_epoch'],
                                                              iter_per_epoch=len(self.tloader),
                                                              epochs=self.optim_cfg['epochs'],
                                                              cosine_weights=self.optim_cfg['cosine_weights']
                                                              )

        # TensorBoard logging only on rank 0; other ranks keep tb_writer=None.
        self.tb_writer = None
        if self.local_rank == 0:
            log_dir = 'runs/'
            print('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % log_dir)
            self.tb_writer = SummaryWriter(log_dir=log_dir)
Code example #2
0
    def __init__(self, cfg_path: str) -> None:
        """Build a single-GPU (cuda:0) RetinaNet training context.

        Single-device variant of the DDP trainer: same YAML-driven setup of
        datasets, loaders, model, optimizer, EMA, ATSS loss and LR schedule,
        but with no process group, no Apex AMP, and no DistributedSampler.

        Args:
            cfg_path: path to the YAML configuration file.
        """
        with open(cfg_path, 'r') as rf:
            self.cfg = yaml.safe_load(rf)
        # Cache each config section; they are consulted throughout training.
        self.data_cfg = self.cfg['data']
        self.model_cfg = self.cfg['model']
        self.optim_cfg = self.cfg['optim']
        self.hyper_params = self.cfg['hyper_params']
        self.val_cfg = self.cfg['val']
        # Echo the full configuration into the run log.
        print(self.data_cfg)
        print(self.model_cfg)
        print(self.optim_cfg)
        print(self.hyper_params)
        print(self.val_cfg)

        # Training set: augmentation on; 'remove_blank' may drop box-less images.
        self.tdata = COCODataSets(img_root=self.data_cfg['train_img_root'],
                                  annotation_path=self.data_cfg['train_annotation_path'],
                                  img_size=self.data_cfg['img_size'],
                                  debug=self.data_cfg['debug'],
                                  use_crowd=self.data_cfg['use_crowd'],
                                  augments=True,
                                  remove_blank=self.data_cfg['remove_blank']
                                  )
        # Plain shuffled loader — no DistributedSampler in this variant.
        self.tloader = DataLoader(dataset=self.tdata,
                                  batch_size=self.data_cfg['batch_size'],
                                  num_workers=self.data_cfg['num_workers'],
                                  collate_fn=self.tdata.collate_fn,
                                  shuffle=True
                                  )
        # Validation set: no augmentation; empty images are kept.
        self.vdata = COCODataSets(img_root=self.data_cfg['val_img_root'],
                                  annotation_path=self.data_cfg['val_annotation_path'],
                                  img_size=self.data_cfg['img_size'],
                                  debug=self.data_cfg['debug'],
                                  use_crowd=self.data_cfg['use_crowd'],
                                  augments=False,
                                  remove_blank=False
                                  )
        self.vloader = DataLoader(dataset=self.vdata,
                                  batch_size=self.data_cfg['batch_size'],
                                  num_workers=self.data_cfg['num_workers'],
                                  collate_fn=self.vdata.collate_fn,
                                  shuffle=False
                                  )
        print("train_data: ", len(self.tdata), " | ",
              "val_data: ", len(self.vdata), " | ",
              "empty_data: ", self.tdata.empty_images_len)
        print("train_iter: ", len(self.tloader), " | ",
              "val_iter: ", len(self.vloader))
        model = RetinaNet(num_cls=self.model_cfg['num_cls'],
                          anchor_ratios=self.model_cfg['anchor_ratios'],
                          anchor_scales=self.model_cfg['anchor_scales'],
                          anchor_sizes=self.model_cfg['anchor_sizes'],
                          strides=self.model_cfg['strides'],
                          backbone=self.model_cfg['backbone'],
                          )
        # Backbone warm-start is disabled in this variant (kept for reference).
        # if self.model_cfg.get("backbone_weight", None):
        #     weights = torch.load(self.model_cfg['backbone_weight'])
        #     model.load_backbone_weighs(weights)
        # Best validation metrics seen so far; updated during validation.
        self.best_map = 0.
        self.best_map50 = 0.
        optimizer = split_optimizer(model, self.optim_cfg)

        # Rank is fixed to 0 so rank-guarded code paths always execute.
        self.local_rank = 0
        self.device = torch.device("cuda:0")
        model.to(self.device)

        self.model = model
        self.optimizer = optimizer
        # EMA keeps an exponentially smoothed copy of the model's weights.
        self.ema = ModelEMA(self.model)

        # WARNING: eval() on a config-supplied string — only safe with trusted
        # config files.
        beta = eval(self.hyper_params['beta']) if isinstance(self.hyper_params['beta'], str) \
            else self.hyper_params['beta']

        # ATSS loss; anchors per location = |ratios| * |scales|.
        self.creterion = ATSSRetinaLoss(top_k=self.hyper_params['top_k'],
                                        anchor_per_loc=len(self.model_cfg['anchor_ratios']) * len(self.model_cfg[
                                                                                                      'anchor_scales']),
                                        alpha=self.hyper_params['alpha'],
                                        gamma=self.hyper_params['gamma'],
                                        beta=beta,
                                        )
        # Warm-up + cosine/multi-step LR schedule (iter_per_epoch is the
        # train-loader length, so it is presumably stepped per iteration).
        self.lr_adjuster = WarmUpCosineDecayMultiStepLRAdjust(init_lr=self.optim_cfg['lr'],
                                                              milestones=self.optim_cfg['milestones'],
                                                              warm_up_epoch=self.optim_cfg['warm_up_epoch'],
                                                              iter_per_epoch=len(self.tloader),
                                                              epochs=self.optim_cfg['epochs'],
                                                              cosine_weights=self.optim_cfg['cosine_weights']
                                                              )
Code example #3
0
class DDPApexProcessor(object):
    """Multi-GPU trainer for an ATSS RetinaNet using DistributedDataParallel
    and NVIDIA Apex mixed precision (opt level O1).

    Everything is driven by a YAML config: run() executes the epoch loop,
    train() performs one epoch with per-iteration LR adjustment and EMA
    updates, and val() computes COCO-style mAP and checkpoints on rank 0.
    """

    def __init__(self, cfg_path: str) -> None:
        """Build the distributed training context from the YAML config.

        Args:
            cfg_path: path to a YAML file with top-level sections 'data',
                'model', 'optim', 'hyper_params', 'val', a 'gpus' string and a
                'model_name' (used in checkpoint file names).
        """
        with open(cfg_path, 'r') as rf:
            self.cfg = yaml.safe_load(rf)
        # Cache each config section; they are consulted throughout training.
        self.data_cfg = self.cfg['data']
        self.model_cfg = self.cfg['model']
        self.optim_cfg = self.cfg['optim']
        self.hyper_params = self.cfg['hyper_params']
        self.val_cfg = self.cfg['val']
        # Echo the full configuration into the run log.
        print(self.data_cfg)
        print(self.model_cfg)
        print(self.optim_cfg)
        print(self.hyper_params)
        print(self.val_cfg)
        # Limit visible GPUs before the NCCL process group is initialized.
        os.environ['CUDA_VISIBLE_DEVICES'] = self.cfg['gpus']
        dist.init_process_group(backend='nccl')
        # Training set: augmentation on; 'remove_blank' may drop box-less images.
        self.tdata = COCODataSets(img_root=self.data_cfg['train_img_root'],
                                  annotation_path=self.data_cfg['train_annotation_path'],
                                  img_size=self.data_cfg['img_size'],
                                  debug=self.data_cfg['debug'],
                                  use_crowd=self.data_cfg['use_crowd'],
                                  augments=True,
                                  remove_blank=self.data_cfg['remove_blank']
                                  )
        # DistributedSampler shards batches across ranks (shuffled for training).
        self.tloader = DataLoader(dataset=self.tdata,
                                  batch_size=self.data_cfg['batch_size'],
                                  num_workers=self.data_cfg['num_workers'],
                                  collate_fn=self.tdata.collate_fn,
                                  sampler=DistributedSampler(dataset=self.tdata, shuffle=True))
        # Validation set: no augmentation; empty images are kept.
        self.vdata = COCODataSets(img_root=self.data_cfg['val_img_root'],
                                  annotation_path=self.data_cfg['val_annotation_path'],
                                  img_size=self.data_cfg['img_size'],
                                  debug=self.data_cfg['debug'],
                                  use_crowd=self.data_cfg['use_crowd'],
                                  augments=False,
                                  remove_blank=False
                                  )
        self.vloader = DataLoader(dataset=self.vdata,
                                  batch_size=self.data_cfg['batch_size'],
                                  num_workers=self.data_cfg['num_workers'],
                                  collate_fn=self.vdata.collate_fn,
                                  sampler=DistributedSampler(dataset=self.vdata, shuffle=False))
        print("train_data: ", len(self.tdata), " | ",
              "val_data: ", len(self.vdata), " | ",
              "empty_data: ", self.tdata.empty_images_len)
        print("train_iter: ", len(self.tloader), " | ",
              "val_iter: ", len(self.vloader))
        model = RetinaNet(num_cls=self.model_cfg['num_cls'],
                          anchor_ratios=self.model_cfg['anchor_ratios'],
                          anchor_scales=self.model_cfg['anchor_scales'],
                          anchor_sizes=self.model_cfg['anchor_sizes'],
                          strides=self.model_cfg['strides'],
                          backbone=self.model_cfg['backbone'],
                          )
        # Optionally warm-start the backbone from a checkpoint file.
        if self.model_cfg.get("backbone_weight", None):
            weights = torch.load(self.model_cfg['backbone_weight'])
            # NOTE(review): 'load_backbone_weighs' looks like a typo for
            # 'load_backbone_weights' in the model API — confirm the method name.
            model.load_backbone_weighs(weights)

        # Best validation metrics seen so far; updated in val() on rank 0.
        self.best_map = 0.
        self.best_map50 = 0.
        optimizer = split_optimizer(model, self.optim_cfg)
        # NOTE(review): dist.get_rank() returns the *global* rank; using it as
        # the CUDA device index is only correct for single-node runs — confirm.
        local_rank = dist.get_rank()
        self.local_rank = local_rank
        self.device = torch.device("cuda", local_rank)
        model.to(self.device)
        # Apex mixed precision (O1); the model is already on its CUDA device.
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
        if self.optim_cfg['sync_bn']:
            # NOTE(review): SyncBN conversion happens *after* amp.initialize
            # here — confirm this ordering is intended.
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        self.model = nn.parallel.distributed.DistributedDataParallel(model,
                                                                     device_ids=[local_rank],
                                                                     output_device=local_rank)
        self.optimizer = optimizer
        # EMA keeps an exponentially smoothed copy of the model's weights.
        self.ema = ModelEMA(self.model)

        # WARNING: eval() on a config-supplied string — only safe with trusted
        # config files.
        beta = eval(self.hyper_params['beta']) if isinstance(self.hyper_params['beta'], str) \
            else self.hyper_params['beta']

        # ATSS loss; anchors per location = |ratios| * |scales|.
        self.creterion = ATSSRetinaLoss(top_k=self.hyper_params['top_k'],
                                        anchor_per_loc=len(self.model_cfg['anchor_ratios']) * len(self.model_cfg[
                                                                                                      'anchor_scales']),
                                        alpha=self.hyper_params['alpha'],
                                        gamma=self.hyper_params['gamma'],
                                        beta=beta,
                                        )
        # Warm-up + cosine/multi-step LR schedule, stepped per iteration in train().
        self.lr_adjuster = WarmUpCosineDecayMultiStepLRAdjust(init_lr=self.optim_cfg['lr'],
                                                              milestones=self.optim_cfg['milestones'],
                                                              warm_up_epoch=self.optim_cfg['warm_up_epoch'],
                                                              iter_per_epoch=len(self.tloader),
                                                              epochs=self.optim_cfg['epochs'],
                                                              cosine_weights=self.optim_cfg['cosine_weights']
                                                              )

        # TensorBoard logging only on rank 0; other ranks keep tb_writer=None.
        self.tb_writer = None
        if self.local_rank == 0:
            log_dir = 'runs/'
            print('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % log_dir)
            self.tb_writer = SummaryWriter(log_dir=log_dir)

    def train(self, epoch: int) -> None:
        """Run one training epoch.

        Args:
            epoch: zero-based epoch index (drives the LR schedule and logging).
        """
        self.model.train()
        # Optionally freeze BatchNorm layers via the project's freeze_bn helper.
        if self.model_cfg['freeze_bn']:
            self.model.apply(freeze_bn)
        # Progress bar only on rank 0; other ranks iterate silently.
        if self.local_rank == 0:
            pbar = tqdm(self.tloader)
        else:
            pbar = self.tloader
        # Per-iteration histories: [total, cls, reg] losses.
        loss_list = [list(), list(), list()]
        lr = 0
        match_num = 0
        for i, (img_tensor, targets_tensor, _) in enumerate(pbar):
            # Multi-scale training: pick a random size and resize the batch.
            if len(self.hyper_params['multi_scale']) > 2:
                target_size = np.random.choice(self.hyper_params['multi_scale'])
                img_tensor = interpolate(img_tensor, mode='bilinear', size=target_size, align_corners=False)
            _, _, h, w = img_tensor.shape
            with torch.no_grad():
                # Scale box coords (cols 3:) to absolute pixels of the possibly
                # resized batch; assumes they arrive normalized — TODO confirm
                # against COCODataSets.collate_fn.
                targets_tensor[:, 3:] = targets_tensor[:, 3:] * torch.tensor(data=[w, h, w, h])
                img_tensor = img_tensor.to(self.device)
                targets_tensor = targets_tensor.to(self.device)
            self.optimizer.zero_grad()
            cls_predicts, reg_predicts, center_predicts, anchors = self.model(img_tensor)
            total_loss, detail_loss, total_num = self.creterion(cls_predicts, reg_predicts, center_predicts, anchors,targets_tensor)

            match_num += total_num
            # Apex scales the loss for mixed-precision backprop.
            with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
#             nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), max_norm=self.optim_cfg['max_norm'],
#                                      norm_type=2)
            # Adjust the per-group LRs for this iteration *before* the step.
            self.lr_adjuster(self.optimizer, i, epoch)
            lr = self.optimizer.param_groups[0]['lr']
            lr1 = self.optimizer.param_groups[1]['lr']
            lr2 = self.optimizer.param_groups[2]['lr']
            self.optimizer.step()
            # EMA update after every optimizer step.
            self.ema.update(self.model)
            loss_cls, loss_reg, los_centerness = detail_loss
            loss_list[0].append(total_loss.item())
            loss_list[1].append(loss_cls.item())
            loss_list[2].append(loss_reg.item())

            # TensorBoard scalars (only rank 0 has a writer).
            if self.tb_writer:
                tags = ['train/box_loss', 'train/cls_loss', 'train/center_loss','train/total_loss', 'x/lr0', 'x/lr1', 'x/lr2']
                for x, tag in zip([loss_reg.item(), loss_cls.item(), los_centerness.item(), total_loss.item(), lr, lr1, lr2],
                                  tags):
                    self.tb_writer.add_scalar(tag, x, i + 1 + epoch * len(self.tloader))

            if self.local_rank == 0:
                pbar.set_description(
                    "epoch:{:2d}|match_num:{:4d}|size:{:3d}|target_loss:{:6.4f}|loss_cls:{:6.4f}|loss_reg:{:6.4f}|lr:{:8.6f}".format(
                        epoch + 1,
                        total_num,
                        h,
                        total_loss.item(),
                        loss_cls.item(),
                        loss_reg.item(),
                        lr
                    ))

        self.ema.update_attr(self.model)
        # Epoch summary: mean of each tracked loss series.
        mean_loss_list = [np.array(item).mean() for item in loss_list]
        print(
            "epoch:{:3d}|match_num:{:4d}|local:{:3d}|target_loss:{:6.4f}|loss_cls:{:6.4f}|loss_reg:{:6.4f}|lr:{:8.6f}"
                .format(epoch + 1,
                        match_num,
                        self.local_rank,
                        mean_loss_list[0],
                        mean_loss_list[1],
                        mean_loss_list[2],
                        lr))

    @torch.no_grad()
    def val(self, epoch: int) -> None:
        """Evaluate on the validation set and write checkpoints on rank 0.

        Computes COCO-style precision/recall/mAP over the whole loader, logs
        the metrics, then (rank 0 only) saves a 'last' checkpoint plus
        best-mAP and best-mAP50 checkpoints containing raw and EMA weights.

        Args:
            epoch: zero-based epoch index (stored in the checkpoint).
        """
        predict_list = list()
        target_list = list()
        self.model.eval()
        # Progress bar only on rank 0; other ranks iterate silently.
        if self.local_rank == 0:
            pbar = tqdm(self.vloader)
        else:
            pbar = self.vloader
        for img_tensor, targets_tensor, _ in pbar:
            _, _, h, w = img_tensor.shape
            # Scale box coords (cols 3:) to absolute pixels of the batch.
            targets_tensor[:, 3:] = targets_tensor[:, 3:] * torch.tensor(data=[w, h, w, h])
            img_tensor = img_tensor.to(self.device)
            targets_tensor = targets_tensor.to(self.device)
            predicts = self.model(img_tensor)

            # Clamp predicted boxes (x to [0, w], y to [0, h]) before NMS.
            for i in range(len(predicts)):
                predicts[i][:, [0, 2]] = predicts[i][:, [0, 2]].clamp(min=0, max=w)
                predicts[i][:, [1, 3]] = predicts[i][:, [1, 3]].clamp(min=0, max=h)
            predicts = fcos_non_max_suppression(predicts,
                                                  conf_thresh=self.val_cfg['conf_thresh'],
                                                  iou_thresh=self.val_cfg['iou_thresh'],
                                                  max_det=self.val_cfg['max_det'],
                                                  )
            # Pair each image's detections with its targets, selected by the
            # batch-index column (col 0); cols 2: keep the remaining fields.
            for i, predict in enumerate(predicts):
                predict_list.append(predict)
                targets_sample = targets_tensor[targets_tensor[:, 0] == i][:, 2:]
                target_list.append(targets_sample)
        mp, mr, map50, map = coco_map(predict_list, target_list)
        if self.tb_writer:
            tags = ['metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95']
            for x, tag in zip([mp, mr, map50, map], tags):
                self.tb_writer.add_scalar(tag, x, epoch + 1)
        print("epoch: {:2d}|local:{:d}|mp:{:6.4f}|mr:{:6.4f}|map50:{:6.4f}|map:{:6.4f}"
              .format(epoch + 1,
                      self.local_rank,
                      mp * 100,
                      mr * 100,
                      map50 * 100,
                      map * 100))
        # Checkpoint paths under <weight_path>/retina, keyed by rank and model name.
        last_weight_path = os.path.join(self.val_cfg['weight_path'],
                                        "retina",
                                        "{:d}_{:s}_last.pth"
                                        .format(self.local_rank, self.cfg['model_name']))
        best_map_weight_path = os.path.join(self.val_cfg['weight_path'],
                                            "retina",
                                            "{:d}_{:s}_best_map.pth"
                                            .format(self.local_rank, self.cfg['model_name']))
        best_map50_weight_path = os.path.join(self.val_cfg['weight_path'],
                                              "retina",
                                              "{:d}_{:s}_best_map50.pth"
                                              .format(self.local_rank, self.cfg['model_name']))
        # Unwrap the DDP module so the saved state_dict has clean key names.
        model_static = self.model.module.state_dict() if is_parallel(self.model) else self.model.state_dict()

        ema_static = self.ema.ema.state_dict()
        cpkt = {
            "ori": model_static,
            "ema": ema_static,
            "map": map * 100,
            "epoch": epoch,
            "map50": map50 * 100
        }
        # NOTE(review): every rank runs the full evaluation, but only rank 0
        # saves and advances best_map/best_map50 (others return here).
        if self.local_rank != 0:
            return
        torch.save(cpkt, last_weight_path)
        if map > self.best_map:
            torch.save(cpkt, best_map_weight_path)
            self.best_map = map
        if map50 > self.best_map50:
            torch.save(cpkt, best_map50_weight_path)
            self.best_map50 = map50

    def run(self) -> None:
        """Train for the configured number of epochs, validating every
        'interval' epochs."""
        for epoch in range(self.optim_cfg['epochs']):
            self.train(epoch)
            if (epoch + 1) % self.val_cfg['interval'] == 0:
                self.val(epoch)
Code example #4
0
class DDPApexProcessor(object):
    """Single-GPU trainer for an ATSS RetinaNet.

    Despite the name, this variant uses neither DistributedDataParallel nor
    Apex AMP: it runs on cuda:0 with plain FP32 training. run() drives the
    epoch loop, train() performs one epoch with per-iteration LR adjustment
    and EMA updates, and val() computes COCO-style mAP and checkpoints.
    """

    def __init__(self, cfg_path: str) -> None:
        """Build the single-GPU training context from the YAML config.

        Args:
            cfg_path: path to a YAML file with top-level sections 'data',
                'model', 'optim', 'hyper_params', 'val' and a 'model_name'
                (used in checkpoint file names).
        """
        with open(cfg_path, 'r') as rf:
            self.cfg = yaml.safe_load(rf)
        # Cache each config section; they are consulted throughout training.
        self.data_cfg = self.cfg['data']
        self.model_cfg = self.cfg['model']
        self.optim_cfg = self.cfg['optim']
        self.hyper_params = self.cfg['hyper_params']
        self.val_cfg = self.cfg['val']
        # Echo the full configuration into the run log.
        print(self.data_cfg)
        print(self.model_cfg)
        print(self.optim_cfg)
        print(self.hyper_params)
        print(self.val_cfg)

        # Training set: augmentation on; 'remove_blank' may drop box-less images.
        self.tdata = COCODataSets(img_root=self.data_cfg['train_img_root'],
                                  annotation_path=self.data_cfg['train_annotation_path'],
                                  img_size=self.data_cfg['img_size'],
                                  debug=self.data_cfg['debug'],
                                  use_crowd=self.data_cfg['use_crowd'],
                                  augments=True,
                                  remove_blank=self.data_cfg['remove_blank']
                                  )
        # Plain shuffled loader — no DistributedSampler in this variant.
        self.tloader = DataLoader(dataset=self.tdata,
                                  batch_size=self.data_cfg['batch_size'],
                                  num_workers=self.data_cfg['num_workers'],
                                  collate_fn=self.tdata.collate_fn,
                                  shuffle=True
                                  )
        # Validation set: no augmentation; empty images are kept.
        self.vdata = COCODataSets(img_root=self.data_cfg['val_img_root'],
                                  annotation_path=self.data_cfg['val_annotation_path'],
                                  img_size=self.data_cfg['img_size'],
                                  debug=self.data_cfg['debug'],
                                  use_crowd=self.data_cfg['use_crowd'],
                                  augments=False,
                                  remove_blank=False
                                  )
        self.vloader = DataLoader(dataset=self.vdata,
                                  batch_size=self.data_cfg['batch_size'],
                                  num_workers=self.data_cfg['num_workers'],
                                  collate_fn=self.vdata.collate_fn,
                                  shuffle=False
                                  )
        print("train_data: ", len(self.tdata), " | ",
              "val_data: ", len(self.vdata), " | ",
              "empty_data: ", self.tdata.empty_images_len)
        print("train_iter: ", len(self.tloader), " | ",
              "val_iter: ", len(self.vloader))
        model = RetinaNet(num_cls=self.model_cfg['num_cls'],
                          anchor_ratios=self.model_cfg['anchor_ratios'],
                          anchor_scales=self.model_cfg['anchor_scales'],
                          anchor_sizes=self.model_cfg['anchor_sizes'],
                          strides=self.model_cfg['strides'],
                          backbone=self.model_cfg['backbone'],
                          )
        # Backbone warm-start is disabled in this variant (kept for reference).
        # if self.model_cfg.get("backbone_weight", None):
        #     weights = torch.load(self.model_cfg['backbone_weight'])
        #     model.load_backbone_weighs(weights)
        # Best validation metrics seen so far; updated in val().
        self.best_map = 0.
        self.best_map50 = 0.
        optimizer = split_optimizer(model, self.optim_cfg)

        # Rank is fixed to 0, so the rank guards below always take the rank-0 path.
        self.local_rank = 0
        self.device = torch.device("cuda:0")
        model.to(self.device)

        self.model = model
        self.optimizer = optimizer
        # EMA keeps an exponentially smoothed copy of the model's weights.
        self.ema = ModelEMA(self.model)

        # WARNING: eval() on a config-supplied string — only safe with trusted
        # config files.
        beta = eval(self.hyper_params['beta']) if isinstance(self.hyper_params['beta'], str) \
            else self.hyper_params['beta']

        # ATSS loss; anchors per location = |ratios| * |scales|.
        self.creterion = ATSSRetinaLoss(top_k=self.hyper_params['top_k'],
                                        anchor_per_loc=len(self.model_cfg['anchor_ratios']) * len(self.model_cfg[
                                                                                                      'anchor_scales']),
                                        alpha=self.hyper_params['alpha'],
                                        gamma=self.hyper_params['gamma'],
                                        beta=beta,
                                        )
        # Warm-up + cosine/multi-step LR schedule, stepped per iteration in train().
        self.lr_adjuster = WarmUpCosineDecayMultiStepLRAdjust(init_lr=self.optim_cfg['lr'],
                                                              milestones=self.optim_cfg['milestones'],
                                                              warm_up_epoch=self.optim_cfg['warm_up_epoch'],
                                                              iter_per_epoch=len(self.tloader),
                                                              epochs=self.optim_cfg['epochs'],
                                                              cosine_weights=self.optim_cfg['cosine_weights']
                                                              )

    def train(self, epoch: int) -> None:
        """Run one training epoch.

        Args:
            epoch: zero-based epoch index (drives the LR schedule and logging).
        """
        self.model.train()
        # Optionally freeze BatchNorm layers via the project's freeze_bn helper.
        if self.model_cfg['freeze_bn']:
            self.model.apply(freeze_bn)
        # local_rank is always 0 here, so the tqdm branch is always taken.
        if self.local_rank == 0:
            pbar = tqdm(self.tloader)
        else:
            pbar = self.tloader
        # Per-iteration histories: [total, cls, reg] losses.
        loss_list = [list(), list(), list()]
        lr = 0
        match_num = 0
        for i, (img_tensor, targets_tensor, _) in enumerate(pbar):
            # Multi-scale training: pick a random size and resize the batch.
            if len(self.hyper_params['multi_scale']) > 2:
                target_size = np.random.choice(self.hyper_params['multi_scale'])
                img_tensor = interpolate(img_tensor, mode='bilinear', size=target_size, align_corners=False)
            _, _, h, w = img_tensor.shape
            with torch.no_grad():
                # Scale box coords (cols 3:) to absolute pixels of the possibly
                # resized batch; assumes they arrive normalized — TODO confirm
                # against COCODataSets.collate_fn.
                targets_tensor[:, 3:] = targets_tensor[:, 3:] * torch.tensor(data=[w, h, w, h])
                img_tensor = img_tensor.to(self.device)
                targets_tensor = targets_tensor.to(self.device)
            self.optimizer.zero_grad()
            cls_predicts, reg_predicts, center_predicts, anchors = self.model(img_tensor)
            total_loss, detail_loss, total_num = self.creterion(cls_predicts, reg_predicts, center_predicts, anchors, targets_tensor)
            match_num += total_num
            # Plain FP32 backward (no AMP loss scaling in this variant).
            total_loss.backward()
            # Adjust the LR for this iteration *before* the optimizer step.
            self.lr_adjuster(self.optimizer, i, epoch)
            lr = self.optimizer.param_groups[0]['lr']
            self.optimizer.step()
            # EMA update after every optimizer step.
            self.ema.update(self.model)
            loss_cls, loss_reg, _ = detail_loss
            loss_list[0].append(total_loss.item())
            loss_list[1].append(loss_cls.item())
            loss_list[2].append(loss_reg.item())
            if self.local_rank == 0:
                pbar.set_description(
                    "epoch:{:2d}|match_num:{:4d}|size:{:3d}|target_loss:{:6.4f}|loss_cls:{:6.4f}|loss_reg:{:6.4f}|lr:{:8.6f}".format(
                        epoch + 1,
                        total_num,
                        h,
                        total_loss.item(),
                        loss_cls.item(),
                        loss_reg.item(),
                        lr
                    ))
        self.ema.update_attr(self.model)
        # Epoch summary: mean of each tracked loss series.
        mean_loss_list = [np.array(item).mean() for item in loss_list]
        print(
            "epoch:{:3d}|match_num:{:4d}|local:{:3d}|target_loss:{:6.4f}|loss_cls:{:6.4f}|loss_reg:{:6.4f}|lr:{:8.6f}"
                .format(epoch + 1,
                        match_num,
                        self.local_rank,
                        mean_loss_list[0],
                        mean_loss_list[1],
                        mean_loss_list[2],
                        lr))

    @torch.no_grad()
    def val(self, epoch: int) -> None:
        """Evaluate on the validation set and write checkpoints.

        Computes COCO-style precision/recall/mAP over the whole loader, logs
        the metrics, then saves a 'last' checkpoint plus best-mAP and
        best-mAP50 checkpoints containing raw and EMA weights.

        Args:
            epoch: zero-based epoch index (stored in the checkpoint).
        """
        predict_list = list()
        target_list = list()
        self.model.eval()
        # local_rank is always 0 here, so the tqdm branch is always taken.
        if self.local_rank == 0:
            pbar = tqdm(self.vloader)
        else:
            pbar = self.vloader
        for img_tensor, targets_tensor, _ in pbar:
            _, _, h, w = img_tensor.shape
            # Scale box coords (cols 3:) to absolute pixels of the batch.
            targets_tensor[:, 3:] = targets_tensor[:, 3:] * torch.tensor(data=[w, h, w, h])
            img_tensor = img_tensor.to(self.device)
            targets_tensor = targets_tensor.to(self.device)
            predicts = self.model(img_tensor)

            # Clamp predicted boxes (x to [0, w], y to [0, h]) before NMS.
            for i in range(len(predicts)):
                predicts[i][:, [0, 2]] = predicts[i][:, [0, 2]].clamp(min=0, max=w)
                predicts[i][:, [1, 3]] = predicts[i][:, [1, 3]].clamp(min=0, max=h)
            predicts = fcos_non_max_suppression(predicts,
                                                  conf_thresh=self.val_cfg['conf_thresh'],
                                                  iou_thresh=self.val_cfg['iou_thresh'],
                                                  max_det=self.val_cfg['max_det'],
                                                  )
            # Pair each image's detections with its targets, selected by the
            # batch-index column (col 0); cols 2: keep the remaining fields.
            for i, predict in enumerate(predicts):
                predict_list.append(predict)
                targets_sample = targets_tensor[targets_tensor[:, 0] == i][:, 2:]
                target_list.append(targets_sample)
        mp, mr, map50, map = coco_map(predict_list, target_list)
        print("epoch: {:2d}|local:{:d}|mp:{:6.4f}|mr:{:6.4f}|map50:{:6.4f}|map:{:6.4f}"
              .format(epoch + 1,
                      self.local_rank,
                      mp * 100,
                      mr * 100,
                      map50 * 100,
                      map * 100))
        # Checkpoint paths under <weight_path>/retina, keyed by rank and model name.
        last_weight_path = os.path.join(self.val_cfg['weight_path'],
                                        "retina",
                                        "{:d}_{:s}_last.pth"
                                        .format(self.local_rank, self.cfg['model_name']))
        best_map_weight_path = os.path.join(self.val_cfg['weight_path'],
                                            "retina",
                                            "{:d}_{:s}_best_map.pth"
                                            .format(self.local_rank, self.cfg['model_name']))
        best_map50_weight_path = os.path.join(self.val_cfg['weight_path'],
                                              "retina",
                                              "{:d}_{:s}_best_map50.pth"
                                              .format(self.local_rank, self.cfg['model_name']))
        # Unwrap the module if wrapped so the saved state_dict has clean key names.
        model_static = self.model.module.state_dict() if is_parallel(self.model) else self.model.state_dict()

        ema_static = self.ema.ema.state_dict()
        cpkt = {
            "ori": model_static,
            "ema": ema_static,
            "map": map * 100,
            "epoch": epoch,
            "map50": map50 * 100
        }
        # Vestigial DDP guard: local_rank is fixed to 0, so this never returns.
        if self.local_rank != 0:
            return
        torch.save(cpkt, last_weight_path)
        if map > self.best_map:
            torch.save(cpkt, best_map_weight_path)
            self.best_map = map
        if map50 > self.best_map50:
            torch.save(cpkt, best_map50_weight_path)
            self.best_map50 = map50

    def run(self) -> None:
        """Train for the configured number of epochs, validating every
        'interval' epochs."""
        for epoch in range(self.optim_cfg['epochs']):
            self.train(epoch)
            if (epoch + 1) % self.val_cfg['interval'] == 0:
                self.val(epoch)