예제 #1
0
파일: detect.py 프로젝트: haok61bkhn/CMND
class Detector_fields:
    """Darknet-based detector that returns per-class boxes and image crops.

    The model, weights, device and class names all come from ``get_config()``.
    """

    def __init__(self):
        """Load the network, its weights and the class-name list."""
        opt = get_config()
        self.model = Darknet(opt.cfg)
        self.model.load_weights(opt.weights)
        self.model.to(opt.device)
        self.class_names = load_class_names(opt.names)
        # (width, height) of the network input; passed to cv2.resize in detect().
        self.size = (self.model.width, self.model.height)
        # Hard-coded class count forwarded to do_detect().
        # NOTE(review): presumably should equal len(self.class_names) - confirm.
        self.num_classes = 6
        print(self.class_names)

    def detect(self, img, thresh=0.6):
        """Run detection on ``img`` and group the results by class name.

        Args:
            img: Image array indexed ``[row, column, ...]``. It is converted
                with ``cv2.COLOR_BGR2RGB`` before inference, so BGR channel
                order is presumably expected - TODO confirm with callers.
            thresh: Threshold forwarded to ``do_detect`` as both its third and
                fifth positional argument.

        Returns:
            A ``(res, resimg)`` tuple of dicts that each hold one list per
            entry of ``self.class_names``. ``res[name]`` collects
            ``[x1, y1, x2, y2]`` pixel boxes in the coordinates of the
            original image; ``resimg[name]`` collects the matching crops
            taken from a copy of the original image.
        """
        names = self.class_names
        res = {name: [] for name in names}
        resimg = {name: [] for name in names}

        im0 = img.copy()
        height, width = img.shape[0], img.shape[1]
        img = cv2.resize(img, self.size)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        boxes = do_detect(self.model, img, thresh, self.num_classes, thresh, 1)

        for box in boxes:
            # box[0:4] are used as centre-x, centre-y, width and height in
            # fractions of the image size; box[6] is the class index.
            name = names[int(box[6])]
            # Only "id" boxes are padded; every other class uses no margin.
            margin = 3 if name == "id" else 0
            x1 = max(int((box[0] - box[2] / 2.0) * width) - margin, 0)
            # The top edge is moved up by one extra pixel.
            y1 = max(int((box[1] - box[3] / 2.0) * height) - margin - 1, 0)
            x2 = min(int((box[0] + box[2] / 2.0) * width + margin), width)
            y2 = min(int((box[1] + box[3] / 2.0) * height + margin), height)

            res[name].append([x1, y1, x2, y2])
            resimg[name].append(im0[y1:y2, x1:x2])
        return res, resimg
예제 #2
0
def train_yolov4(cfg):
    """Build the network described by ``cfg`` and hand it to ``train``.

    Sets ``CUDA_VISIBLE_DEVICES`` from ``cfg.gpu``, selects CUDA when it is
    available, and wraps the model in ``torch.nn.DataParallel`` when more than
    one GPU is visible. On ``KeyboardInterrupt`` the current state dict is
    written to ``INTERRUPTED.pth`` and the process exits with status 0.
    """
    logging = init_logger(log_dir=cfg.TRAIN_TENSORBOARD_DIR)

    os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    model = (Darknet(cfg.cfgfile) if cfg.use_darknet_cfg else Yolov4(
        cfg.pretrained, n_classes=cfg.classes))

    # TODO: load checkpoints for the Yolov4 (non-darknet-cfg) branch, e.g.:
    #   WORK_FOLDER = '/mnt/bos/modules/perception/emergency_detection'
    #   weightfile = os.path.join(WORK_FOLDER, 'checkpoints/Yolov4_epoch291.pth')
    #   pretrained_dict = torch.load(weightfile, map_location=torch.device('cuda'))
    #   model.load_state_dict(pretrained_dict)

    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        model = torch.nn.DataParallel(model)
    model.to(device=device)

    try:
        train(model=model, config=cfg, epochs=cfg.TRAIN_EPOCHS, device=device)
    except KeyboardInterrupt:
        # Keep whatever has been learned so far before leaving.
        torch.save(model.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            # Fall back to an immediate exit when SystemExit is raised.
            os._exit(0)
예제 #3
0
def train(model,
          device,
          config,
          epochs=5,
          batch_size=1,
          save_cp=True,
          log_step=20,
          img_scale=0.5):
    """Train ``model`` with ``Yolo_loss``, evaluating and checkpointing per epoch.

    Gradients are accumulated over ``config.subdivisions`` mini-batches before
    each optimizer/scheduler step. After every epoch a fresh inference-mode
    copy of the network is evaluated and its COCO bbox statistics are written
    to TensorBoard. The function logs through a module-level ``logging``
    object.

    Args:
        model: Network to optimise; may be wrapped in ``torch.nn.DataParallel``.
        device: ``torch.device`` that batches, the loss and the evaluation
            model are moved to.
        config: Settings object (label paths, batch/subdivisions, learning
            rate schedule, optimizer name, checkpoint directory, ...).
        epochs: Number of passes over the training set.
        batch_size: Unused; the loader batch size is
            ``config.batch // config.subdivisions``. Kept for compatibility.
        save_cp: When true, save a checkpoint after every epoch.
        log_step: Log every ``log_step * config.subdivisions`` mini-batches.
        img_scale: Unused. Kept for compatibility.

    Raises:
        ValueError: If ``config.TRAIN_OPTIMIZER`` is neither ``adam`` nor
            ``sgd`` (case-insensitive).
    """
    loader_batch = config.batch // config.subdivisions
    train_dataset = Yolo_dataset(config.train_label, config, train=True)
    val_dataset = Yolo_dataset(config.val_label, config, train=False)

    n_train = len(train_dataset)
    n_val = len(val_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=loader_batch,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True,
                              collate_fn=collate)

    # NOTE(review): shuffle=True/drop_last=True means the last incomplete
    # validation batch is skipped each epoch - confirm this is intended.
    val_loader = DataLoader(val_dataset,
                            batch_size=loader_batch,
                            shuffle=True,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True,
                            collate_fn=val_collate)

    run_tag = (
        f'OPT_{config.TRAIN_OPTIMIZER}_LR_{config.learning_rate}_BS_{config.batch}'
        f'_Sub_{config.subdivisions}_Size_{config.width}')
    writer = SummaryWriter(log_dir=config.TRAIN_TENSORBOARD_DIR,
                           filename_suffix=run_tag,
                           comment=run_tag)
    global_step = 0
    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {config.batch}
        Subdivisions:    {config.subdivisions}
        Learning rate:   {config.learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images size:     {config.width}
        Optimizer:       {config.TRAIN_OPTIMIZER}
        Dataset classes: {config.classes}
        Train label path:{config.train_label}
        Pretrained:
    ''')

    def burnin_schedule(i):
        """Return the learning-rate multiplier for scheduler step ``i``."""
        if i < config.burn_in:
            # Quartic warm-up from 0 to 1 over the burn-in period.
            factor = pow(i / config.burn_in, 4)
        elif i < config.steps[0]:
            factor = 1.0
        elif i < config.steps[1]:
            factor = 0.1
        else:
            factor = 0.01
        return factor

    # The optimizer runs at learning_rate / batch; logged values multiply the
    # batch size back in.
    optimizer_name = config.TRAIN_OPTIMIZER.lower()
    if optimizer_name == 'adam':
        optimizer = optim.Adam(
            model.parameters(),
            lr=config.learning_rate / config.batch,
            betas=(0.9, 0.999),
            eps=1e-08,
        )
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(
            params=model.parameters(),
            lr=config.learning_rate / config.batch,
            momentum=config.momentum,
            weight_decay=config.decay,
        )
    else:
        # Previously this fell through to an unbound-variable error below.
        writer.close()
        raise ValueError(
            f'Unsupported TRAIN_OPTIMIZER: {config.TRAIN_OPTIMIZER!r}')
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)

    criterion = Yolo_loss(device=device,
                          batch=loader_batch,
                          n_classes=config.classes)

    metric_tags = ('train/AP', 'train/AP50', 'train/AP75', 'train/AP_small',
                   'train/AP_medium', 'train/AP_large', 'train/AR1',
                   'train/AR10', 'train/AR100', 'train/AR_small',
                   'train/AR_medium', 'train/AR_large')

    save_prefix = 'Yolov4_epoch'
    saved_models = deque()
    model.train()
    for epoch in range(epochs):
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img',
                  ncols=50) as pbar:
            for batch in train_loader:
                global_step += 1
                images = batch[0].to(device=device, dtype=torch.float32)
                bboxes = batch[1].to(device=device)

                bboxes_pred = model(images)
                loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2 = criterion(
                    bboxes_pred, bboxes)
                loss.backward()

                # Step only once every `subdivisions` mini-batches so the
                # gradients of one full batch are accumulated first.
                if global_step % config.subdivisions == 0:
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()

                if global_step % (log_step * config.subdivisions) == 0:
                    lr = scheduler.get_last_lr()[0] * config.batch
                    total = loss.item()
                    loss_parts = {
                        'loss_xy': loss_xy.item(),
                        'loss_wh': loss_wh.item(),
                        'loss_obj': loss_obj.item(),
                        'loss_cls': loss_cls.item(),
                        'loss_l2': loss_l2.item(),
                    }
                    writer.add_scalar('train/Loss', total, global_step)
                    for key, value in loss_parts.items():
                        writer.add_scalar(f'train/{key}', value, global_step)
                    writer.add_scalar('lr', lr, global_step)
                    pbar.set_postfix(**{
                        'loss (batch)': total,
                        **loss_parts,
                        'lr': lr
                    })
                    logging.debug(
                        'Train step_{}: loss : {},loss xy : {},loss wh : {},'
                        'loss obj : {},loss cls : {},loss l2 : {},lr : {}'.
                        format(global_step, total, *loss_parts.values(), lr))

                pbar.update(images.shape[0])

        # Evaluate on an inference-mode copy built from `config` (the
        # original read a module-level `cfg` here instead of the parameter).
        if config.use_darknet_cfg:
            eval_model = Darknet(config.cfgfile, inference=True)
        else:
            eval_model = Yolov4(config.pretrained,
                                n_classes=config.classes,
                                inference=True)
        # Strip the DataParallel wrapper so parameter names match eval_model.
        bare_model = (model.module
                      if isinstance(model, torch.nn.DataParallel) else model)
        eval_model.load_state_dict(bare_model.state_dict())
        eval_model.to(device)
        evaluator = evaluate(eval_model, val_loader, config, device)
        del eval_model

        stats = evaluator.coco_eval['bbox'].stats
        for index, tag in enumerate(metric_tags):
            writer.add_scalar(tag, stats[index], global_step)

        if save_cp:
            try:
                os.makedirs(config.checkpoints, exist_ok=True)
                logging.info('Created checkpoint directory')
            except OSError as err:
                logging.info(f'could not create checkpoint directory: {err}')
            save_path = os.path.join(config.checkpoints,
                                     f'{save_prefix}{epoch + 1}.pth')
            torch.save(model.state_dict(), save_path)
            logging.info(f'Checkpoint {epoch + 1} saved !')
            saved_models.append(save_path)
            # Keep at most keep_checkpoint_max files; a non-positive limit
            # disables pruning.
            if len(saved_models) > config.keep_checkpoint_max > 0:
                model_to_remove = saved_models.popleft()
                try:
                    os.remove(model_to_remove)
                except OSError:
                    logging.info(f'failed to remove {model_to_remove}')

    writer.close()
예제 #4
0
if __name__ == "__main__":
    # Script entry point: parse the configuration, build the network on the
    # selected device and run training. The names `logging`, `cfg`, `device`
    # and `model` are deliberately module-level.
    logging = init_logger(log_dir='log')
    cfg = get_args(**Cfg)
    os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    model = (Darknet(cfg.cfgfile) if cfg.use_darknet_cfg else Yolov4(
        cfg.pretrained, n_classes=cfg.classes))

    # Spread the model over every visible GPU when there is more than one.
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(device=device)

    try:
        train(model=model, config=cfg, epochs=cfg.TRAIN_EPOCHS, device=device)
    except KeyboardInterrupt:
        # Save the current weights before leaving on Ctrl-C.
        torch.save(model.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
예제 #5
0
def train(model, device, config, epochs=5, batch_size=1, save_cp=True, log_step=20, img_scale=0.5):
    """Train ``model`` with ``Yolo_loss``; optionally evaluate only or per epoch.

    When ``config.only_evaluate`` is set, the current weights are evaluated
    once and the function returns without training. Otherwise gradients are
    accumulated over ``config.subdivisions`` mini-batches per optimizer step,
    a ``.weights`` checkpoint is written after each epoch when ``save_cp`` is
    true, and COCO statistics are logged each epoch when
    ``config.evaluate_when_train`` is set. The function logs through a
    module-level ``logging`` object.

    Args:
        model: Network to optimise; may be wrapped in ``torch.nn.DataParallel``.
        device: ``torch.device`` used for batches, loss and evaluation.
        config: Settings object (label paths, batch/subdivisions, learning
            rate schedule, optimizer name, checkpoint directory, ...).
        epochs: Number of passes over the training set.
        batch_size: Unused; the loader batch size is
            ``config.batch // config.subdivisions``. Kept for compatibility.
        save_cp: When true, save a checkpoint after every epoch.
        log_step: Write TensorBoard scalars every
            ``log_step * config.subdivisions`` mini-batches.
        img_scale: Unused. Kept for compatibility.

    Raises:
        NotImplementedError: In evaluate-only mode without a darknet cfg.
        ValueError: If ``config.TRAIN_OPTIMIZER`` is neither ``adam`` nor
            ``sgd`` (case-insensitive).
    """
    # TODO: add resume support. Resuming would need: all config values, the
    # full yolov4-custom.cfg, the weights, the epoch index and the point the
    # learning-rate schedule had reached.

    def unwrap(net):
        """Return the underlying module when ``net`` is a DataParallel wrapper."""
        return net.module if isinstance(net, torch.nn.DataParallel) else net

    # True for any CUDA device; comparing against torch.device("cuda") would
    # be False for an indexed device such as "cuda:0".
    use_cuda = device.type == "cuda"
    loader_batch = config.batch // config.subdivisions

    # Build the datasets; config.train_label is the path of the label text
    # file (e.g. data/coins.txt, per the original author).
    train_dataset = Yolo_dataset(config.train_label, config, train=True)
    val_dataset = Yolo_dataset(config.val_label, config, train=False)

    n_train = len(train_dataset)
    n_val = len(val_dataset)

    # Original author's note on the loaders: with num_workers=0 loading works
    # for either pin_memory setting, while num_workers=8 hung for either
    # setting. The dataset loads images with OpenCV, and some OpenCV
    # functions start their own threads; nesting that inside loader workers
    # can hang (possibly OS dependent). Suggested remedies: call
    # cv2.setNumThreads(0) right after importing cv2 in the dataset module
    # (recommended), or load/preprocess with PIL (slower than OpenCV).
    train_loader = DataLoader(train_dataset, batch_size=loader_batch, shuffle=True,
                              num_workers=8, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=loader_batch, shuffle=False,
                            num_workers=8, pin_memory=True, drop_last=False, collate_fn=val_collate)

    # Ground-truth file for COCO evaluation; only needed when evaluating.
    if config.only_evaluate or config.evaluate_when_train:
        tgtFile = makeTgtJson(val_loader, config.categories)

    run_tag = (f'OPT_{config.TRAIN_OPTIMIZER}_LR_{config.learning_rate}_BS_{config.batch}'
               f'_Sub_{config.subdivisions}_Size_{config.width}')
    writer = SummaryWriter(log_dir=config.TRAIN_TENSORBOARD_DIR,
                           filename_suffix=run_tag,
                           comment=run_tag)

    # Global counter of processed mini-batches.
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {config.batch}
        Subdivisions:    {config.subdivisions}
        Learning rate:   {config.learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images size:     {config.width}
        Optimizer:       {config.TRAIN_OPTIMIZER}
        Dataset classes: {config.classes}
        Train label path:{config.train_label}
        Pretrained:      {config.pretrainedWeight is not None or config.Pretrained is not None}
    ''')
    if config.only_evaluate:
        # Close the writer on every exit path of evaluate-only mode.
        try:
            if not config.use_darknet_cfg:
                raise NotImplementedError
            eval_model = Darknet(config.cfgfile)
            eval_model.load_state_dict(unwrap(model).state_dict())
            eval_model.to(device)
            eval_model.eval()
            # NOTE(review): the per-epoch evaluation below also passes
            # config.width and config.height - confirm the defaults match.
            resFile = evaluate(eval_model, config.val_label, config.dataset_dir, use_cuda)
            if resFile is None:
                debugPrint("detect 0 boxes in the val set")
            else:
                cocoEvaluate(tgtFile, resFile)
        finally:
            writer.close()
        return

    # Learning-rate schedule: ramp up during burn-in, then decay in steps.
    def burnin_schedule(i):
        """Return the learning-rate multiplier for scheduler step ``i``."""
        # i counts scheduler steps (iterations), not epochs.
        if i < config.burn_in:  # quartic warm-up
            factor = pow(i / config.burn_in, 4)
        elif i < config.steps[0]:  # first stage
            factor = 1.0
        elif i < config.steps[1]:  # second stage
            factor = 0.1
        else:  # third stage
            factor = 0.01
        return factor

    # The optimizer runs at learning_rate / batch.
    optimizer_name = config.TRAIN_OPTIMIZER.lower()
    if optimizer_name == 'adam':
        optimizer = optim.Adam(
            model.parameters(),
            lr=config.learning_rate / config.batch,
            betas=(0.9, 0.999),
            eps=1e-08,
        )
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(
            params=model.parameters(),
            lr=config.learning_rate / config.batch,
            momentum=config.momentum,
            weight_decay=config.decay,
        )
    else:
        # Previously this fell through to an unbound-variable error below.
        writer.close()
        raise ValueError(f'Unsupported TRAIN_OPTIMIZER: {config.TRAIN_OPTIMIZER!r}')

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)

    # Loss object applied to the network outputs and the target boxes.
    criterion = Yolo_loss(device=device, batch=loader_batch, n_classes=config.classes)

    metric_tags = ('train/AP', 'train/AP50', 'train/AP75', 'train/AP_small',
                   'train/AP_medium', 'train/AP_large', 'train/AR1', 'train/AR10',
                   'train/AR100', 'train/AR_small', 'train/AR_medium', 'train/AR_large')

    save_prefix = 'Yolov4_epoch'
    saved_models = deque()
    for epoch in range(epochs):
        # Evaluation at the end of an epoch switches to eval mode, so
        # training mode is restored at the start of every epoch.
        model.train()
        logging.info("===Train===")
        for i, batch in enumerate(train_loader):
            global_step += 1
            images = batch[0].to(device=device, dtype=torch.float32)
            bboxes = batch[1].to(device=device)

            bboxes_pred = model(images)
            loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2 = criterion(bboxes_pred, bboxes)
            loss.backward()

            # Step only once every `subdivisions` mini-batches so the
            # gradients of one full batch are accumulated first.
            if global_step % config.subdivisions == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()

            logging.info("Epoch:[{:3}/{}],step:[{:3}/{}],total loss:{:.2f}|lr:{:.5f}".format(epoch + 1, epochs, i + 1, len(train_loader), loss.item(), scheduler.get_last_lr()[0]))

            # log_step counts optimizer iterations (default 20).
            if global_step % (log_step * config.subdivisions) == 0:
                lr = scheduler.get_last_lr()[0] * config.batch
                total = loss.item()
                loss_parts = {
                    'loss_xy': loss_xy.item(),
                    'loss_wh': loss_wh.item(),
                    'loss_obj': loss_obj.item(),
                    'loss_cls': loss_cls.item(),
                    'loss_l2': loss_l2.item(),
                }
                writer.add_scalar('train/Loss', total, global_step)
                for key, value in loss_parts.items():
                    writer.add_scalar(f'train/{key}', value, global_step)
                writer.add_scalar('lr', lr, global_step)

                logging.debug('Train step_{}: loss : {},loss xy : {},loss wh : {},'
                              'loss obj : {},loss cls : {},loss l2 : {},lr : {}'
                              .format(global_step, total, *loss_parts.values(), lr))
        if save_cp:
            # Create the checkpoint directory on first use.
            if not os.path.exists(config.checkpoints):
                os.makedirs(config.checkpoints, exist_ok=True)
                logging.info('Created checkpoint directory')
            save_path = os.path.join(config.checkpoints, f'{save_prefix}{epoch + 1}.weights')
            # save_weights lives on the underlying network, not on the
            # DataParallel wrapper.
            unwrap(model).save_weights(save_path)
            logging.info(f'Checkpoint {epoch + 1} saved !')
            # Keep only the newest keep_checkpoint_max checkpoints.
            saved_models.append(save_path)
            if len(saved_models) > config.keep_checkpoint_max > 0:
                model_to_remove = saved_models.popleft()
                try:
                    os.remove(model_to_remove)
                except OSError:
                    logging.info(f'failed to remove {model_to_remove}')

        if config.evaluate_when_train:
            # Evaluation is best-effort: a failure is reported and training
            # continues with the next epoch.
            try:
                model.eval()
                resFile = evaluate(model, config.val_label, config.dataset_dir, use_cuda, config.width, config.height)
                if resFile is None:
                    continue
                stats = cocoEvaluate(tgtFile, resFile)

                logging.info("===Val===")
                logging.info("Epoch:[{:3}/{}],AP:{:.3f}|AP50:{:.3f}|AP75:{:.3f}|APs:{:.3f}|APm:{:.3f}|APl:{:.3f}".format(
                    epoch + 1, epochs, stats[0], stats[1], stats[2], stats[3], stats[4], stats[5]))
                logging.info("Epoch:[{:3}/{}],AR1:{:.3f}|AR10:{:.3f}|AR100:{:.3f}|ARs:{:.3f}|ARm:{:.3f}|ARl:{:.3f}".format(
                    epoch + 1, epochs, stats[6], stats[7], stats[8], stats[9], stats[10], stats[11]))

                for index, tag in enumerate(metric_tags):
                    writer.add_scalar(tag, stats[index], global_step)
            except Exception:
                debugPrint("evaluate meets an exception, here is the exception info:")
                traceback.print_exc()
                debugPrint("ignore error in evaluate and continue training")

    writer.close()
예제 #6
0
def train(
    model,
    device,
    config,
    epochs=5,
    save_cp=True,
    log_step=20,
):
    """Train the BEV network and checkpoint whenever validation loss improves.

    Gradients are accumulated over ``config.subdivisions`` mini-batches per
    optimizer step. On every even-numbered epoch index the summed validation
    loss is computed; when ``save_cp`` is true and that sum is lower than the
    best seen so far, a checkpoint is written. The function logs through a
    module-level ``logging`` object.

    Args:
        model: Network to optimise; ``model(images)[0]`` is fed to the loss.
        device: ``torch.device`` the image batches are moved to.
        config: Settings object (batch/subdivisions, learning rate schedule,
            optimizer name, checkpoint directory, ...).
        epochs: Number of passes over the training set.
        save_cp: When true, save a checkpoint on validation improvement.
        log_step: Log every ``log_step * config.subdivisions`` mini-batches.

    Raises:
        ValueError: If ``config.TRAIN_OPTIMIZER`` is neither ``adam`` nor
            ``sgd`` (case-insensitive).
    """
    loader_batch = config.batch // config.subdivisions
    loss_names = ("loss", "loss_xy", "loss_wl", "loss_rot", "loss_obj",
                  "loss_noobj")

    # Get dataloaders
    train_dataset = Yolo_BEV_dataset(config, split="train")
    val_dataset = Yolo_BEV_dataset(config, split="val")

    train_loader = DataLoader(
        train_dataset,
        batch_size=loader_batch,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate,
    )
    # NOTE(review): drop_last=True skips the last incomplete validation
    # batch - confirm this is intended.
    val_loader = DataLoader(
        val_dataset,
        batch_size=loader_batch,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate,
    )

    # define summary writer
    run_tag = (
        f"OPT_{config.TRAIN_OPTIMIZER}_LR_{config.learning_rate}_BS_{config.batch}"
        f"_Sub_{config.subdivisions}_Size_{config.width}")
    writer = SummaryWriter(
        log_dir=config.TRAIN_TENSORBOARD_DIR,
        filename_suffix=run_tag,
        comment=run_tag,
    )

    # log (reports the `epochs` argument, which is what the loop uses)
    n_train = len(train_dataset)
    n_val = len(val_dataset)
    global_step = 0
    logging.info(f"""Starting training:
        Epochs:          {epochs}
        Batch size:      {config.batch}
        Subdivisions:    {config.subdivisions}
        Learning rate:   {config.learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Input height:    {config.height}
        Input width:     {config.width}
        Optimizer:       {config.TRAIN_OPTIMIZER}
        Dataset classes: {config.classes}
    """)

    # learning rate setup
    def burnin_schedule(i):
        """Return the learning-rate multiplier for scheduler step ``i``."""
        if i < config.burn_in:
            # Quartic warm-up from 0 to 1 over the burn-in period.
            factor = pow(i / config.burn_in, 4)
        elif i < config.steps[0]:
            factor = 1.0
        elif i < config.steps[1]:
            factor = 0.1
        else:
            factor = 0.01
        return factor

    # optimizer + scheduler; the optimizer runs at learning_rate / batch
    optimizer_name = config.TRAIN_OPTIMIZER.lower()
    if optimizer_name == "adam":
        optimizer = optim.Adam(
            model.parameters(),
            lr=config.learning_rate / config.batch,
            betas=(0.9, 0.999),
            eps=1e-08,
        )
    elif optimizer_name == "sgd":
        optimizer = optim.SGD(
            params=model.parameters(),
            lr=config.learning_rate / config.batch,
            momentum=config.momentum,
            weight_decay=config.decay,
        )
    else:
        # Previously this fell through to an unbound-variable error below.
        writer.close()
        raise ValueError(
            f"Unsupported TRAIN_OPTIMIZER: {config.TRAIN_OPTIMIZER!r}")

    # The scheduler is stepped once per optimizer step, so the factor is a
    # function of the optimizer-step count.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)

    # loss function
    criterion = Yolo_loss(
        cfg=config,
        device=device,
    )

    # start training
    save_prefix = "Yolov4_BEV_flat_epoch"
    saved_models = deque()
    # Best validation loss so far; it must outlive the epoch loop, otherwise
    # every evaluation counts as an improvement.
    min_eval_loss = math.inf
    model.train()
    for epoch in range(epochs):
        with tqdm(total=n_train,
                  desc=f"Epoch {epoch + 1}/{epochs}",
                  unit="img",
                  ncols=75) as pbar:
            for batch in train_loader:
                # get batch
                global_step += 1
                images = batch[0].float().to(device=device)
                labels = batch[1]

                # compute loss
                preds = model(images)[0]
                losses = criterion(preds, labels)
                losses[0].backward()

                # update weights once every `subdivisions` mini-batches
                if global_step % config.subdivisions == 0:
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()

                # log
                if global_step % (log_step * config.subdivisions) == 0:
                    lr = scheduler.get_last_lr()[0] * config.batch
                    values = [part.item() for part in losses]
                    writer.add_scalar("train/Loss", values[0], global_step)
                    for name, value in zip(loss_names[1:], values[1:]):
                        writer.add_scalar(f"train/{name}", value, global_step)
                    writer.add_scalar("lr", lr, global_step)

                    postfix = {"loss (batch)": values[0]}
                    postfix.update(zip(loss_names[1:], values[1:]))
                    postfix["lr"] = lr
                    pbar.set_postfix(postfix)
                    logging.debug(
                        "Train step_{}: loss : {},loss xy : {},loss wl : {},"
                        "loss rot : {},loss obj : {},loss noobj : {},lr : {}".
                        format(global_step, *values, lr))

                pbar.update(images.shape[0])

        # Validation runs on even epoch indices only (0, 2, 4, ...).
        if epoch % 2 != 0:
            continue

        # The original also built an inference-mode copy of the network here
        # but never used it; the loss is computed with `model`, as before.
        # NOTE(review): `model` stays in training mode during validation, so
        # mode-dependent layers behave as in training - confirm intended.
        eval_totals = dict.fromkeys(loss_names, 0.0)
        # no_grad: nothing is back-propagated, so no autograd graph is needed.
        # global_step is left untouched so validation does not shift the
        # gradient-accumulation phase.
        with torch.no_grad(), tqdm(total=n_val,
                                   desc=f"Eval {(epoch + 1) // 2}",
                                   unit="img",
                                   ncols=75) as epbar:
            for batch in val_loader:
                images = batch[0].float().to(device=device)
                labels = batch[1]

                labels_pred = model(images)[0]
                val_losses = criterion(labels_pred, labels)
                for name, part in zip(loss_names, val_losses):
                    eval_totals[name] += part.item()

                epbar.update(images.shape[0])

        eval_loss = eval_totals["loss"]
        # The totals are plain floats summed over the validation batches.
        logging.debug(
            "Val step_{}: loss : {},loss xy : {},loss wl : {},"
            "loss rot : {},loss obj : {},loss noobj : {},lr : {}".format(
                global_step,
                *eval_totals.values(),
                scheduler.get_last_lr()[0] * config.batch,
            ))

        # save checkpoint only when the validation loss improved
        if save_cp and eval_loss < min_eval_loss:
            min_eval_loss = eval_loss
            try:
                os.makedirs(config.checkpoints, exist_ok=True)
                logging.info("Created checkpoint directory")
            except OSError as err:
                logging.info(f"could not create checkpoint directory: {err}")
            save_path = os.path.join(config.checkpoints,
                                     f"{save_prefix}{epoch + 1}.pth")
            torch.save(model.state_dict(), save_path)
            logging.info(f"Checkpoint {epoch + 1} saved !")
            saved_models.append(save_path)
            # Keep at most keep_checkpoint_max files; a non-positive limit
            # disables pruning.
            if len(saved_models) > config.keep_checkpoint_max > 0:
                model_to_remove = saved_models.popleft()
                try:
                    os.remove(model_to_remove)
                except OSError:
                    logging.info(f"failed to remove {model_to_remove}")

    writer.close()