Exemple #1
0
def run():
    """Load a trained CenterNet and run detection on a camera, image or video.

    All options come from ``parse_args()``: ``cuda``, ``input_size``,
    ``version``, ``conf_thresh``, ``nms_thresh``, ``use_nms``,
    ``trained_model``, ``visual_threshold``, ``mode`` and the path options
    used by the selected mode (``path_to_img``, ``path_to_vid``,
    ``path_to_saveVid``). An unrecognised ``mode`` runs no detection.
    An unknown ``version`` prints an error and exits.
    """
    args = parse_args()

    # use cuda
    device = torch.device("cuda") if args.cuda else torch.device("cpu")

    input_size = [args.input_size, args.input_size]

    # load net
    if args.version == 'centernet':
        from models.centernet import CenterNet
        net = CenterNet(device,
                        input_size=input_size,
                        num_classes=80,
                        conf_thresh=args.conf_thresh,
                        nms_thresh=args.nms_thresh,
                        use_nms=args.use_nms)
    else:
        # Fail clearly instead of hitting a NameError on ``net`` below.
        print('Unknown version !!!')
        exit()

    # Map the checkpoint onto the selected device: a hard-coded 'cuda'
    # makes loading fail on CPU-only machines even when --cuda is off.
    net.load_state_dict(torch.load(args.trained_model, map_location=device))
    net.to(device).eval()
    print('Finished loading model!')

    # run -- every mode uses the same preprocessing, so build it once.
    transform = BaseTransform(net.input_size)
    if args.mode == 'camera':
        detect(net, device, transform,
               thresh=args.visual_threshold, mode=args.mode)
    elif args.mode == 'image':
        detect(net, device, transform,
               thresh=args.visual_threshold, mode=args.mode,
               path_to_img=args.path_to_img)
    elif args.mode == 'video':
        detect(net, device, transform,
               thresh=args.visual_threshold, mode=args.mode,
               path_to_vid=args.path_to_vid,
               path_to_save=args.path_to_saveVid)
Exemple #2
0
def test():
    """Evaluate a trained CenterNet on the VOC2007 ``test`` split.

    Reads the module-level ``args`` (``cuda``, ``voc_root``, ``version``,
    ``trained_model``, ``visual_threshold``) and passes the model, dataset
    and a ``BaseTransform`` with fixed mean/std to ``test_net``.
    An unknown ``version`` prints an error and exits.
    """
    # get device
    if args.cuda:
        print('use cuda')
        cudnn.benchmark = True
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # load net
    num_classes = len(VOC_CLASSES)
    testset = VOCDetection(args.voc_root, [('2007', 'test')], None,
                           VOCAnnotationTransform())

    cfg = config.voc_cfg
    if args.version == 'centernet':
        from models.centernet import CenterNet
        net = CenterNet(device,
                        input_size=cfg['min_dim'],
                        num_classes=num_classes)
    else:
        # Fail clearly instead of hitting a NameError on ``net`` below.
        print('Unknown version !!!')
        exit()

    net.load_state_dict(torch.load(args.trained_model, map_location=device))
    net.to(device).eval()
    print('Finished loading model!')

    # evaluation
    test_net(net,
             device,
             testset,
             BaseTransform(net.input_size,
                           mean=(0.406, 0.456, 0.485),
                           std=(0.225, 0.224, 0.229)),
             thresh=args.visual_threshold)
    def __init__(self, cfg):
        """Set up the model, optimizer, LR scheduler, data loaders and losses.

        Args:
            cfg: configuration object. Reads ``cfg.Distributed.gpu_id``,
                ``cfg.Train.lr`` and ``cfg.Train.lr_milestones``; the whole
                object is also forwarded to ``make_dataloader`` and to the
                parent ``__init__``.
        """
        self.cfg = cfg
        network = CenterNet(cfg).cuda(cfg.Distributed.gpu_id)

        # Adam on all model parameters; MultiStepLR multiplies the LR by 0.1
        # at every milestone in cfg.Train.lr_milestones.
        self.optimizer = optim.Adam(network.parameters(), lr=cfg.Train.lr)
        self.lr_sch = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=cfg.Train.lr_milestones,
            gamma=0.1,
        )

        loaders = make_dataloader(cfg, collate_fn='ctnet')
        self.training_loader, self.validation_loader = loaders

        # The parent constructor receives the model and scheduler built
        # above, so it has to run after them.
        super(CenterNetOperator, self).__init__(
            cfg=self.cfg, model=network, lr_sch=self.lr_sch)

        # TODO: change it to our class
        self.focal_loss = FocalLossHM()
        self.l1_loss = RegL1Loss()

        # True only when this instance runs on GPU 0.
        self.main_proc_flag = (cfg.Distributed.gpu_id == 0)
Exemple #4
0
def run():
    """Run CenterNet detection on a camera, image or video stream.

    ``args.setup`` selects the VOC (20 classes) or COCO (80 classes) config;
    any other value prints an error and exits. An unknown ``args.version``
    also prints an error and exits. The remaining options (``cuda``,
    ``trained_model``, ``mode``, ``vis_thresh`` and the per-mode paths) come
    from ``parse_args()``.
    """
    args = parse_args()

    if args.cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    if args.setup == 'VOC':
        print('use VOC style')
        cfg = config.voc_cfg
        num_classes = 20
    elif args.setup == 'COCO':
        print('use COCO style')
        cfg = config.coco_cfg
        num_classes = 80
    else:
        print('Only support VOC and COCO !!!')
        exit(0)

    if args.version == 'centernet':
        from models.centernet import CenterNet
        net = CenterNet(device, input_size=cfg['min_dim'], num_classes=num_classes, use_nms=True)
    else:
        # Fail clearly instead of hitting a NameError on ``net`` below.
        print('Unknown version !!!')
        exit()

    net.load_state_dict(torch.load(args.trained_model, map_location=device))
    net.to(device).eval()
    print('Finished loading model!')

    # Every mode uses the same preprocessing, so build it once.
    transform = BaseTransform(net.input_size,
                              mean=(0.406, 0.456, 0.485),
                              std=(0.225, 0.224, 0.229))

    # run
    if args.mode == 'camera':
        detect(net, device, transform,
               thresh=args.vis_thresh, mode=args.mode, setup=args.setup)
    elif args.mode == 'image':
        detect(net, device, transform,
               thresh=args.vis_thresh, mode=args.mode,
               path_to_img=args.path_to_img, setup=args.setup)
    elif args.mode == 'video':
        detect(net, device, transform,
               thresh=args.vis_thresh, mode=args.mode,
               path_to_vid=args.path_to_vid,
               path_to_save=args.path_to_saveVid, setup=args.setup)
Exemple #5
0
def test():
    """Evaluate a trained CenterNet on COCO val2017 or VOC2007 test.

    Reads the module-level ``args`` (``cuda``, ``dataset``, ``dataset_root``,
    ``debug``, ``version``, ``trained_model``, ``visual_threshold``).
    ``args.dataset`` must be ``'COCO'`` or ``'VOC'``. Any other dataset, or
    an unknown ``version``, prints an error and exits.
    """
    # get device
    if args.cuda:
        print('use cuda')
        cudnn.benchmark = True
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # load net
    if args.dataset == 'COCO':
        cfg = config.coco_cfg
        num_classes = 80
        testset = COCODataset(
                    data_dir=args.dataset_root,
                    json_file='instances_val2017.json',
                    name='val2017',
                    img_size=cfg['min_dim'][0],
                    debug=args.debug)
    elif args.dataset == 'VOC':
        cfg = config.voc_cfg
        # VOC has 20 object classes; previously 80 was used for both
        # datasets, giving a head that does not match a 20-class model.
        num_classes = 20
        testset = VOCDetection(VOC_ROOT, [('2007', 'test')], None, VOCAnnotationTransform())
    else:
        # Without this branch ``cfg`` and ``testset`` would be unbound.
        print('Only support VOC and COCO !!!')
        exit(0)

    if args.version == 'centernet':
        from models.centernet import CenterNet
        net = CenterNet(device, input_size=cfg['min_dim'], num_classes=num_classes)
    else:
        # Fail clearly instead of hitting a NameError on ``net`` below.
        print('Unknown version !!!')
        exit()

    # Load onto the selected device; a hard-coded 'cuda' breaks CPU-only runs.
    net.load_state_dict(torch.load(args.trained_model, map_location=device))
    net.to(device).eval()
    print('Finished loading model!')

    # evaluation
    test_net(net, device, testset,
             BaseTransform(net.input_size, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)),
             thresh=args.visual_threshold)
Exemple #6
0
def train():
    """Train CenterNet on MSCOCO.

    Options come from ``parse_args()``; ``min_dim``, ``max_epoch`` and
    ``lr_epoch`` come from ``coco_cfg``. Optimisation is SGD with an
    optional warm-up (disabled by ``--no_warm_up``). With ``--cos``, a
    cosine schedule applies after epoch 20 and a 1e-5 floor applies for the
    last 20 epochs. Otherwise the LR is decayed 10x at each epoch in
    ``lr_epoch``. COCO AP is evaluated every ``eval_epoch`` epochs, and a
    checkpoint is written to ``<save_folder>/<version>/`` every 10 epochs.
    """
    args = parse_args()
    data_dir = args.dataset_root

    path_to_save = os.path.join(args.save_folder, args.version)
    os.makedirs(path_to_save, exist_ok=True)

    cfg = coco_cfg

    if args.cuda:
        print('use cuda')
        cudnn.benchmark = True
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    input_size = cfg['min_dim']
    # ``min_dim`` is handed to CenterNet as ``input_size`` and may be an
    # [h, w] sequence; the '%d' in the progress line needs a single int.
    log_size = input_size[0] if isinstance(input_size, (list, tuple)) else input_size
    dataset = COCODataset(
                data_dir=data_dir,
                img_size=cfg['min_dim'],
                transform=SSDAugmentation(cfg['min_dim'], mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)),
                debug=args.debug)

    # build model
    if args.version == 'centernet':
        from models.centernet import CenterNet

        net = CenterNet(device, input_size=input_size, num_classes=args.num_classes, trainable=True)
        print('Let us train centernet on the COCO dataset ......')

    else:
        print('Unknown version !!!')
        exit()

    print("Setting Arguments.. : ", args)
    print("----------------------------------------------------------")
    print('Loading the MSCOCO dataset...')
    print('Training model on:', dataset.name)
    print('The dataset size:', len(dataset))
    print("----------------------------------------------------------")

    # use tfboard
    if args.tfboard:
        print('use tensorboard')
        from torch.utils.tensorboard import SummaryWriter
        c_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        log_path = os.path.join('log/coco/', args.version, c_time)
        os.makedirs(log_path, exist_ok=True)

        writer = SummaryWriter(log_path)

    model = net
    model.to(device).train()

    dataloader = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=args.batch_size,
                    shuffle=True,
                    collate_fn=detection_collate,
                    num_workers=args.num_workers)

    evaluator = COCOAPIEvaluator(
                    data_dir=data_dir,
                    img_size=cfg['min_dim'],
                    device=device,
                    transform=BaseTransform(cfg['min_dim'], mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229))
                    )

    # optimizer setup
    base_lr = args.lr
    tmp_lr = base_lr
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)

    max_epoch = cfg['max_epoch']
    epoch_size = len(dataset) // args.batch_size

    # start training loop
    t0 = time.time()

    for epoch in range(max_epoch):

        # use cos lr
        if args.cos and epoch > 20 and epoch <= max_epoch - 20:
            tmp_lr = 0.00001 + 0.5*(base_lr-0.00001)*(1+math.cos(math.pi*(epoch-20)*1./ (max_epoch-20)))
            set_lr(optimizer, tmp_lr)

        elif args.cos and epoch > max_epoch - 20:
            tmp_lr = 0.00001
            set_lr(optimizer, tmp_lr)

        # use step lr
        else:
            if epoch in cfg['lr_epoch']:
                tmp_lr = tmp_lr * 0.1
                set_lr(optimizer, tmp_lr)

        for iter_i, (images, targets) in enumerate(dataloader):
            # WarmUp strategy for learning rate: LR grows as the 4th power
            # of training progress until epoch ``wp_epoch``.
            if not args.no_warm_up:
                if epoch < args.wp_epoch:
                    tmp_lr = base_lr * pow((iter_i+epoch*epoch_size)*1. / (args.wp_epoch*epoch_size), 4)
                    set_lr(optimizer, tmp_lr)

                elif epoch == args.wp_epoch and iter_i == 0:
                    tmp_lr = base_lr
                    set_lr(optimizer, tmp_lr)

            targets = [label.tolist() for label in targets]
            targets = tools.gt_creator(input_size, net.stride, args.num_classes, targets)

            # to device
            images = images.to(device)
            targets = torch.tensor(targets).float().to(device)

            # forward and loss
            cls_loss, txty_loss, twth_loss, total_loss = model(images, target=targets)

            # backprop and update
            total_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if iter_i % 10 == 0:
                if args.tfboard:
                    # viz loss
                    global_step = iter_i + epoch * epoch_size
                    writer.add_scalar('class loss', cls_loss.item(), global_step)
                    writer.add_scalar('txty loss',  txty_loss.item(), global_step)
                    writer.add_scalar('twth loss',  twth_loss.item(), global_step)
                    writer.add_scalar('total loss', total_loss.item(), global_step)

                t1 = time.time()
                print('[Epoch %d/%d][Iter %d/%d][lr %.6f]'
                      '[Loss: cls %.2f || txty %.2f || twth %.2f ||total %.2f || size %d || time: %.2f]'
                      % (epoch+1, max_epoch, iter_i, epoch_size, tmp_lr,
                         cls_loss.item(), txty_loss.item(), twth_loss.item(), total_loss.item(), log_size, t1-t0),
                      flush=True)

                t0 = time.time()

        # COCO evaluation
        if (epoch + 1) % args.eval_epoch == 0:
            model.trainable = False
            # Put the model in inference mode for the evaluation pass; it is
            # switched back with ``model.train()`` below.
            model.eval()
            # evaluate
            ap50_95, ap50 = evaluator.evaluate(model)
            print('ap50 : ', ap50)
            print('ap50_95 : ', ap50_95)
            # convert to training mode.
            model.trainable = True
            model.train()
            if args.tfboard:
                writer.add_scalar('val/COCOAP50', ap50, epoch + 1)
                writer.add_scalar('val/COCOAP50_95', ap50_95, epoch + 1)

        if (epoch + 1) % 10 == 0:
            print('Saving state, epoch:', epoch + 1)
            torch.save(model.state_dict(), os.path.join(path_to_save,
                        args.version + '_' + repr(epoch + 1) + '.pth')
                       )

    if args.tfboard:
        # Flush pending tensorboard events and release the log file.
        writer.close()
Exemple #7
0
def train():
    """Train CenterNet on VOC or COCO, with optional mosaic and multi-scale.

    Options come from ``parse_args()``; ``max_epoch`` and ``lr_epoch`` come
    from ``train_cfg``. ``args.dataset`` must be ``'voc'`` (20 classes) or
    ``'coco'`` (80 classes); anything else prints an error and exits.
    Weights can be resumed from ``args.resume``, with training starting at
    ``args.start_epoch``. The evaluator runs whenever ``epoch % eval_epoch
    == 0``. A checkpoint is written to ``<save_folder>/<dataset>/<version>/``
    every 10 epochs.
    """
    args = parse_args()

    path_to_save = os.path.join(args.save_folder, args.dataset, args.version)
    os.makedirs(path_to_save, exist_ok=True)

    # cuda
    if args.cuda:
        print('use cuda')
        cudnn.benchmark = True
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # mosaic augmentation
    if args.mosaic:
        print('use Mosaic Augmentation ...')

    # multi-scale
    if args.multi_scale:
        print('use the multi-scale trick ...')
        train_size = [640, 640]
        val_size = [512, 512]
    else:
        train_size = [512, 512]
        val_size = [512, 512]

    cfg = train_cfg
    # dataset and evaluator
    print("Setting Arguments.. : ", args)
    print("----------------------------------------------------------")
    print('Loading the dataset...')

    if args.dataset == 'voc':
        data_dir = VOC_ROOT
        num_classes = 20
        dataset = VOCDetection(root=data_dir, img_size=train_size[0],
                                transform=SSDAugmentation(train_size),
                                mosaic=args.mosaic
                                )

        evaluator = VOCAPIEvaluator(data_root=data_dir,
                                    img_size=val_size,
                                    device=device,
                                    transform=BaseTransform(val_size),
                                    labelmap=VOC_CLASSES
                                    )

    elif args.dataset == 'coco':
        data_dir = coco_root
        num_classes = 80
        dataset = COCODataset(
                    data_dir=data_dir,
                    img_size=train_size[0],
                    transform=SSDAugmentation(train_size),
                    debug=args.debug,
                    mosaic=args.mosaic
                    )

        evaluator = COCOAPIEvaluator(
                        data_dir=data_dir,
                        img_size=val_size,
                        device=device,
                        transform=BaseTransform(val_size)
                        )

    else:
        print('unknow dataset !! Only support voc and coco !!')
        exit(0)

    print('Training model on:', dataset.name)
    print('The dataset size:', len(dataset))
    print("----------------------------------------------------------")

    # dataloader
    dataloader = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=args.batch_size,
                    shuffle=True,
                    collate_fn=detection_collate,
                    num_workers=args.num_workers,
                    pin_memory=True
                    )

    # build model
    if args.version == 'centernet':
        from models.centernet import CenterNet

        net = CenterNet(device, input_size=train_size, num_classes=num_classes, trainable=True)
        print('Let us train centernet on the %s dataset ......' % (args.dataset))

    else:
        print('Unknown version !!!')
        exit()

    model = net
    model.to(device).train()

    # use tfboard
    if args.tfboard:
        print('use tensorboard')
        from torch.utils.tensorboard import SummaryWriter
        c_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        # Log under the dataset actually being trained; this was hard-coded
        # to 'log/coco/' even for VOC runs.
        log_path = os.path.join('log', args.dataset, args.version, c_time)
        os.makedirs(log_path, exist_ok=True)

        writer = SummaryWriter(log_path)

    # keep training
    if args.resume is not None:
        print('keep training model: %s' % (args.resume))
        model.load_state_dict(torch.load(args.resume, map_location=device))
        # NOTE(review): only the weights are restored. Step-LR decays for
        # ``lr_epoch`` milestones before ``args.start_epoch`` are not
        # re-applied, so a resumed run restarts from base_lr. Confirm that
        # this is intended.

    # optimizer setup
    base_lr = args.lr
    tmp_lr = base_lr
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay
                          )

    max_epoch = cfg['max_epoch']
    epoch_size = len(dataset) // args.batch_size

    # start training loop
    t0 = time.time()

    for epoch in range(args.start_epoch, max_epoch):

        # use cos lr
        if args.cos and epoch > 20 and epoch <= max_epoch - 20:
            tmp_lr = 0.00001 + 0.5*(base_lr-0.00001)*(1+math.cos(math.pi*(epoch-20)*1./ (max_epoch-20)))
            set_lr(optimizer, tmp_lr)

        elif args.cos and epoch > max_epoch - 20:
            tmp_lr = 0.00001
            set_lr(optimizer, tmp_lr)

        # use step lr
        else:
            if epoch in cfg['lr_epoch']:
                tmp_lr = tmp_lr * 0.1
                set_lr(optimizer, tmp_lr)

        for iter_i, (images, targets) in enumerate(dataloader):
            # WarmUp strategy for learning rate: LR grows as the 4th power
            # of training progress until epoch ``wp_epoch``.
            if not args.no_warm_up:
                if epoch < args.wp_epoch:
                    tmp_lr = base_lr * pow((iter_i+epoch*epoch_size)*1. / (args.wp_epoch*epoch_size), 4)
                    set_lr(optimizer, tmp_lr)

                elif epoch == args.wp_epoch and iter_i == 0:
                    tmp_lr = base_lr
                    set_lr(optimizer, tmp_lr)

            # to device
            images = images.to(device)

            # multi-scale trick
            if iter_i % 10 == 0 and iter_i > 0 and args.multi_scale:
                # randomly choose a new size
                size = random.randint(10, 19) * 32
                train_size = [size, size]
                model.set_grid(train_size)
            if args.multi_scale:
                # interpolate
                images = torch.nn.functional.interpolate(images, size=train_size, mode='bilinear', align_corners=False)

            # make train label -- use the class count the model was built
            # with (from the dataset), not args.num_classes, so the targets
            # match the model's head.
            targets = [label.tolist() for label in targets]
            targets = tools.gt_creator(train_size, net.stride, num_classes, targets)
            targets = torch.tensor(targets).float().to(device)

            # forward and loss
            cls_loss, txty_loss, twth_loss, total_loss = model(images, target=targets)

            # backprop
            total_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if iter_i % 10 == 0:
                if args.tfboard:
                    # viz loss
                    global_step = iter_i + epoch * epoch_size
                    writer.add_scalar('class loss', cls_loss.item(), global_step)
                    writer.add_scalar('txty loss',  txty_loss.item(), global_step)
                    writer.add_scalar('twth loss',  twth_loss.item(), global_step)
                    writer.add_scalar('total loss', total_loss.item(), global_step)

                t1 = time.time()
                print('[Epoch %d/%d][Iter %d/%d][lr %.6f]'
                      '[Loss: cls %.2f || txty %.2f || twth %.2f ||total %.2f || size %d || time: %.2f]'
                      % (epoch+1, max_epoch, iter_i, epoch_size, tmp_lr,
                         cls_loss.item(), txty_loss.item(), twth_loss.item(), total_loss.item(), train_size[0], t1-t0),
                      flush=True)

                t0 = time.time()

        # evaluation
        if (epoch) % args.eval_epoch == 0:
            model.trainable = False
            model.set_grid(val_size)
            model.eval()

            # evaluate
            evaluator.evaluate(model)

            # convert to training mode.
            model.trainable = True
            model.set_grid(train_size)
            model.train()

        # save model
        if (epoch + 1) % 10 == 0:
            print('Saving state, epoch:', epoch + 1)
            torch.save(model.state_dict(), os.path.join(path_to_save,
                        args.version + '_' + repr(epoch + 1) + '.pth')
                       )

    if args.tfboard:
        # Flush pending tensorboard events and release the log file.
        writer.close()
Exemple #8
0
    if args.cuda:
        print('use cuda')
        torch.backends.cudnn.benchmark = True
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # input size
    input_size = [args.input_size, args.input_size]

    # load net
    if args.version == 'centernet':
        from models.centernet import CenterNet
        net = CenterNet(device,
                        input_size=input_size,
                        num_classes=num_classes,
                        backbone=args.backbone,
                        use_nms=args.use_nms)
    else:
        # Fail clearly instead of hitting a NameError on ``net`` below.
        print('Unknown version !!!')
        exit()

    # Load the weights onto the selected device; a hard-coded 'cuda'
    # map_location fails on CPU-only machines.
    net.load_state_dict(torch.load(args.trained_model, map_location=device))
    net.eval()
    print('Finished loading model!')
    net = net.to(device)

    # evaluation (no autograd graph is needed)
    with torch.no_grad():
        if args.dataset == 'voc':
            voc_test(net, device, input_size)
        elif args.dataset == 'coco-val':
            coco_test(net, device, input_size, test=False)
Exemple #9
0
        class_indexs = coco_class_index
        num_classes = 80
        dataset = COCODataset(data_dir=coco_root,
                              json_file='instances_val2017.json',
                              name='val2017',
                              img_size=input_size[0])

    class_colors = [(np.random.randint(255), np.random.randint(255),
                     np.random.randint(255)) for _ in range(num_classes)]

    # load net
    if args.version == 'centernet':
        from models.centernet import CenterNet
        net = CenterNet(device,
                        input_size=input_size,
                        num_classes=num_classes,
                        conf_thresh=args.conf_thresh,
                        nms_thresh=args.nms_thresh,
                        use_nms=args.use_nms)

    net.load_state_dict(torch.load(args.trained_model, map_location=device))
    net.to(device).eval()
    print('Finished loading model!')

    # evaluation
    test(net=net,
         device=device,
         testset=dataset,
         transform=BaseTransform(input_size),
         thresh=args.visual_threshold,
         class_colors=class_colors,
         class_names=class_names,
Exemple #10
0
    print('ap50_95 : ', ap50_95)


if __name__ == '__main__':
    # Evaluate a trained CenterNet checkpoint using the COCO config.
    # A ``global`` declaration at module scope is a no-op, so none is needed.
    cfg = coco_cfg

    if not args.cuda:
        device = torch.device("cpu")
    else:
        print('use cuda')
        torch.backends.cudnn.benchmark = True
        device = torch.device("cuda")

    # Only the 'centernet' version is supported.
    if args.version != 'centernet':
        print('Unknown Version !!!')
        exit()

    from models.centernet import CenterNet
    model = CenterNet(device,
                      input_size=cfg['min_dim'],
                      num_classes=args.num_classes)

    # load model weights, then switch to inference mode on the device
    model.load_state_dict(torch.load(args.trained_model, map_location=device))
    model.eval()
    model.to(device)
    print('Finished loading model!')

    test(model, device)
Exemple #11
0
def create_model(cfg, arch):
    """Build and return ``CenterNet(cfg, arch)``.

    Args:
        cfg: configuration object forwarded to ``CenterNet``.
        arch: architecture identifier forwarded to ``CenterNet``
            (e.g. ``'litnet'``).
    """
    return CenterNet(cfg, arch)
Exemple #12
0
                    start_lr *= 0.1
            for param_group in optimizer.param_groups:
                param_group['lr'] = start_lr
            print('Resumed optimizer with start lr', start_lr)
        else:
            print('No optimizer parameters in checkpoint.')
    if optimizer is not None:
        return model, optimizer, start_epoch
    else:
        return model


def save_model(path, epoch, model, optimizer=None):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict with keys ``'epoch'`` and ``'state_dict'``, plus
    ``'optimizer'`` when an optimizer is given.

    Args:
        path: destination passed to ``torch.save``.
        epoch: epoch number stored alongside the weights.
        model: the module to save. ``DataParallel`` and
            ``DistributedDataParallel`` wrappers are unwrapped via their
            ``.module`` attribute, so the saved keys carry no ``module.``
            prefix and load into an unwrapped model.
        optimizer: optional optimizer whose ``state_dict()`` is saved too.
    """
    # Previously only DataParallel was unwrapped; a DDP-wrapped model
    # would have been saved with 'module.'-prefixed keys.
    parallel_wrappers = (torch.nn.DataParallel,
                         torch.nn.parallel.DistributedDataParallel)
    if isinstance(model, parallel_wrappers):
        model = model.module
    data = {'epoch': epoch, 'state_dict': model.state_dict()}
    if optimizer is not None:
        data['optimizer'] = optimizer.state_dict()
    torch.save(data, path)


if __name__ == '__main__':
    # Smoke check: print a layer summary of a 'litnet' CenterNet for a
    # 3x512x512 input, computed on the CPU.
    from torchsummary import summary
    from config import Config

    cfg = Config()
    model = CenterNet(cfg, 'litnet')
    input_shape = (3, 512, 512)
    summary(model, input_shape, device='cpu')
Exemple #13
0
def train():
    """Train CenterNet on VOC0712, optionally finetuning from COCO weights.

    Options come from ``parse_args()``; ``min_dim``, ``max_epoch`` and
    ``lr_epoch`` come from ``voc_cfg``. Optimisation is SGD with an optional
    warm-up. With ``--cos``, a cosine schedule applies after epoch 20 and a
    1e-5 floor applies for the last 20 epochs; otherwise the LR is decayed
    10x at each epoch in ``lr_epoch``. A checkpoint is written to
    ``<save_folder>/<version>/`` every 10 epochs.
    """
    args = parse_args()

    path_to_save = os.path.join(args.save_folder, args.version)
    os.makedirs(path_to_save, exist_ok=True)

    cfg = voc_cfg

    if args.cuda:
        print('use cuda')
        cudnn.benchmark = True
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    input_size = cfg['min_dim']
    # ``min_dim`` is handed to CenterNet as ``input_size`` and may be an
    # [h, w] sequence; the '%d' in the progress line needs a single int.
    log_size = (input_size[0] if isinstance(input_size, (list, tuple))
                else input_size)
    dataset = VOCDetection(root=args.dataset_root,
                           transform=SSDAugmentation(cfg['min_dim'],
                                                     mean=(0.406, 0.456,
                                                           0.485),
                                                     std=(0.225, 0.224,
                                                          0.229)))

    # build model
    if args.version == 'centernet':
        from models.centernet import CenterNet

        net = CenterNet(device,
                        input_size=input_size,
                        num_classes=args.num_classes,
                        trainable=True)
        print('Let us train centernet on the VOC0712 dataset ......')

    else:
        print('Unknown version !!!')
        exit()

    # finetune the model trained on COCO; strict=False tolerates keys that
    # differ between the checkpoint and this model.
    if args.resume is not None:
        print('finetune COCO trained ')
        net.load_state_dict(torch.load(args.resume, map_location=device),
                            strict=False)

    # use tfboard
    if args.tfboard:
        print('use tensorboard')
        from torch.utils.tensorboard import SummaryWriter
        c_time = time.strftime('%Y-%m-%d %H:%M:%S',
                               time.localtime(time.time()))
        log_path = os.path.join('log/voc/', args.version, c_time)
        os.makedirs(log_path, exist_ok=True)

        writer = SummaryWriter(log_path)

    print(
        "----------------------------------------Object Detection--------------------------------------------"
    )
    model = net
    model.to(device).train()

    base_lr = args.lr
    tmp_lr = base_lr
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    print("----------------------------------------------------------")
    print("Let's train OD network !")
    print('Training on:', dataset.name)
    print('The dataset size:', len(dataset))
    print("----------------------------------------------------------")

    epoch_size = len(dataset) // args.batch_size
    max_epoch = cfg['max_epoch']

    data_loader = data.DataLoader(dataset,
                                  args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)
    t0 = time.time()

    # start training
    for epoch in range(max_epoch):

        # use cos lr
        if args.cos and epoch > 20 and epoch <= max_epoch - 20:
            tmp_lr = 0.00001 + 0.5 * (base_lr - 0.00001) * (
                1 + math.cos(math.pi * (epoch - 20) * 1. / (max_epoch - 20)))
            set_lr(optimizer, tmp_lr)

        elif args.cos and epoch > max_epoch - 20:
            tmp_lr = 0.00001
            set_lr(optimizer, tmp_lr)

        # use step lr
        else:
            if epoch in cfg['lr_epoch']:
                tmp_lr = tmp_lr * 0.1
                set_lr(optimizer, tmp_lr)

        for iter_i, (images, targets) in enumerate(data_loader):
            # WarmUp strategy for learning rate: LR grows as the 4th power
            # of training progress until epoch ``wp_epoch``.
            if not args.no_warm_up:
                if epoch < args.wp_epoch:
                    tmp_lr = base_lr * pow((iter_i + epoch * epoch_size) * 1. /
                                           (args.wp_epoch * epoch_size), 4)
                    set_lr(optimizer, tmp_lr)

                elif epoch == args.wp_epoch and iter_i == 0:
                    tmp_lr = base_lr
                    set_lr(optimizer, tmp_lr)

            targets = [label.tolist() for label in targets]

            # make train label
            targets = tools.gt_creator(input_size, net.stride,
                                       args.num_classes, targets)

            # to device
            images = images.to(device)
            targets = torch.tensor(targets).float().to(device)

            # forward and loss
            cls_loss, txty_loss, twth_loss, total_loss = model(images,
                                                               target=targets)

            # backprop and update
            total_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if iter_i % 10 == 0:
                if args.tfboard:
                    # viz loss
                    global_step = iter_i + epoch * epoch_size
                    writer.add_scalar('class loss', cls_loss.item(),
                                      global_step)
                    writer.add_scalar('txty loss', txty_loss.item(),
                                      global_step)
                    writer.add_scalar('twth loss', twth_loss.item(),
                                      global_step)
                    writer.add_scalar('total loss', total_loss.item(),
                                      global_step)

                t1 = time.time()
                print(
                    '[Epoch %d/%d][Iter %d/%d][lr %.6f]'
                    '[Loss: cls %.2f || txty %.2f || twth %.2f ||total %.2f || size %d || time: %.2f]'
                    % (epoch + 1, max_epoch, iter_i, epoch_size, tmp_lr,
                       cls_loss.item(), txty_loss.item(), twth_loss.item(),
                       total_loss.item(), log_size, t1 - t0),
                    flush=True)

                t0 = time.time()

        if (epoch + 1) % 10 == 0:
            print('Saving state, epoch:', epoch + 1)
            torch.save(
                model.state_dict(),
                os.path.join(path_to_save,
                             args.version + '_' + repr(epoch + 1) + '.pth'))

    if args.tfboard:
        # Flush pending tensorboard events and release the log file.
        writer.close()